diff --git a/dbgpt/rag/retriever/rerank.py b/dbgpt/rag/retriever/rerank.py index e38d601f7..69240c4b9 100644 --- a/dbgpt/rag/retriever/rerank.py +++ b/dbgpt/rag/retriever/rerank.py @@ -24,7 +24,9 @@ class Ranker(ABC): self.rank_fn = rank_fn @abstractmethod - def rank(self, candidates_with_scores: List) -> List[Chunk]: + def rank( + self, candidates_with_scores: List[Chunk], query: Optional[str] = None + ) -> List[Chunk]: """Return top k chunks after ranker. Rank algorithm implementation return topk documents by candidates @@ -32,10 +34,9 @@ class Ranker(ABC): Args: candidates_with_scores: List[Tuple] - topk: int - + query: Optional[str] Return: - List[Document] + List[Chunk] """ def _filter(self, candidates_with_scores: List) -> List[Chunk]: @@ -77,11 +78,17 @@ class Ranker(ABC): class DefaultRanker(Ranker): """Default Ranker.""" - def __init__(self, topk: int, rank_fn: Optional[RANK_FUNC] = None): + def __init__( + self, + topk: int = 4, + rank_fn: Optional[RANK_FUNC] = None, + ): """Create Default Ranker with topk and rank_fn.""" super().__init__(topk, rank_fn) - def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]: + def rank( + self, candidates_with_scores: List[Chunk], query: Optional[str] = None + ) -> List[Chunk]: """Return top k chunks after ranker. Return top k documents by candidates similarity score @@ -105,11 +112,17 @@ class DefaultRanker(Ranker): class RRFRanker(Ranker): """RRF(Reciprocal Rank Fusion) Ranker.""" - def __init__(self, topk: int, rank_fn: Optional[RANK_FUNC] = None): + def __init__( + self, + topk: int = 4, + rank_fn: Optional[RANK_FUNC] = None, + ): """RRF rank algorithm implementation.""" super().__init__(topk, rank_fn) - def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]: + def rank( + self, candidates_with_scores: List[Chunk], query: Optional[str] = None + ) -> List[Chunk]: """RRF rank algorithm implementation. 
@register_resource(
    _("CrossEncoder Rerank"),
    "cross_encoder_ranker",
    category=ResourceCategory.RAG,
    description=_("CrossEncoder ranker."),
    parameters=[
        Parameter.build_from(
            _("Top k"),
            "topk",
            int,
            description=_("The number of top k documents."),
        ),
        Parameter.build_from(
            _("Rerank Model"),
            "model",
            str,
            description=_("rerank model name, e.g., 'BAAI/bge-reranker-base'."),
        ),
        Parameter.build_from(
            _("device"),
            "device",
            str,
            description=_("device name, e.g., 'cpu'."),
        ),
    ],
)
class CrossEncoderRanker(Ranker):
    """Ranker that rescores candidates with a sentence-transformers CrossEncoder.

    Each candidate's content is paired with the query, scored by the model, and
    the candidates are returned in descending score order, truncated to ``topk``.
    """

    def __init__(
        self,
        topk: int = 4,
        model: str = "BAAI/bge-reranker-base",
        device: str = "cpu",
        rank_fn: Optional[RANK_FUNC] = None,
    ):
        """Load the cross-encoder model and initialise the base ranker.

        Args:
            topk: The number of top k documents to keep.
            model: Rerank model name or local path,
                e.g. 'BAAI/bge-reranker-base'.
            device: Device name, e.g. 'cpu'.
            rank_fn: Optional rank function forwarded to the base ``Ranker``.

        Raises:
            ImportError: If ``sentence-transformers`` is not installed.

        Refer: https://www.sbert.net/examples/applications/cross-encoder/README.html
        """
        try:
            # Imported lazily so the dependency is only needed by this ranker.
            from sentence_transformers import CrossEncoder
        except ImportError as exc:
            # Chain the cause so the underlying import failure stays visible.
            raise ImportError(
                "please `pip install sentence-transformers`",
            ) from exc
        self._model = CrossEncoder(model, max_length=512, device=device)
        super().__init__(topk, rank_fn)

    def rank(
        self, candidates_with_scores: List[Chunk], query: Optional[str] = None
    ) -> List[Chunk]:
        """Rerank candidates by cross-encoder relevance to the query.

        The ``score`` attribute of every candidate is overwritten in place with
        the model's relevance score before sorting.

        Args:
            candidates_with_scores: Candidate chunks to rerank.
            query: Query text. Required whenever there is at least one
                candidate, because the model scores (query, content) pairs.

        Returns:
            The candidates sorted by descending score, at most ``topk`` of them.
            An empty list when there are no candidates.

        Raises:
            ValueError: If ``query`` is None while candidates are present.
        """
        if not candidates_with_scores:
            # Nothing to score; skip the model call entirely.
            return []
        if query is None:
            # Fail early with a clear message instead of feeding None to the
            # tokenizer.
            raise ValueError(
                "CrossEncoderRanker.rank requires a query to score candidates"
            )

        query_content_pairs = [
            [query, candidate.content] for candidate in candidates_with_scores
        ]
        rank_scores = self._model.predict(query_content_pairs)

        for candidate, score in zip(candidates_with_scores, rank_scores):
            # Store a plain float rather than the model's raw output value.
            candidate.score = float(score)

        new_candidates_with_scores = sorted(
            candidates_with_scores, key=lambda x: x.score, reverse=True
        )
        return new_candidates_with_scores[: self.topk]
"""Rerank retrieved chunks with a cross-encoder model.

A cross-encoder reranker is a neural network model that takes a query and a
chunk as input and outputs a score representing how relevant the chunk is to
the query.

Pretrained cross-encoder models can be downloaded from
https://huggingface.co/models.

Example:
    python examples/rag/cross_encoder_rerank_example.py
"""
import asyncio
import os

from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.rag.retriever.rerank import CrossEncoderRanker
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector


def _create_vector_connector():
    """Build the Chroma vector store connector used by this example."""
    persist_dir = os.path.join(PILOT_PATH, "data")
    print(f"persist_path:{persist_dir}")
    store_config = ChromaVectorConfig(
        name="example_cross_encoder_rerank",
        persist_path=persist_dir,
    )
    embedding_factory = DefaultEmbeddingFactory(
        default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
    )
    return VectorStoreConnector.from_default(
        "Chroma",
        vector_store_config=store_config,
        embedding_fn=embedding_factory.create(),
    )


def _print_chunks(chunks):
    """Print each chunk's content and score, numbered from 1."""
    for position, chunk in enumerate(chunks, start=1):
        print(f"----{position}.chunk content:{chunk.content}\n score:{chunk.score}")


async def main():
    """Index a markdown file, retrieve chunks for a query, then rerank them."""
    # Load the document and persist its chunks into the vector store.
    knowledge = KnowledgeFactory.from_file_path(
        os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
    )
    connector = _create_vector_connector()
    assembler = EmbeddingAssembler.load_from_knowledge(
        knowledge=knowledge,
        chunk_parameters=ChunkParameters(chunk_strategy="CHUNK_BY_MARKDOWN_HEADER"),
        vector_store_connector=connector,
    )
    assembler.persist()

    # Retrieve scored chunks for the query from the embedding retriever.
    retriever = assembler.as_retriever(3)
    query = "what is awel talk about"
    chunks = await retriever.aretrieve_with_scores(query, 0.3)

    print("before rerank results:\n")
    _print_chunks(chunks)

    # Rescore the retrieved chunks with the cross-encoder model.
    reranker = CrossEncoderRanker(
        topk=3, model=os.path.join(MODEL_PATH, "bge-reranker-base")
    )
    reranked_chunks = reranker.rank(chunks, query=query)

    print("after cross-encoder rerank results:\n")
    _print_chunks(reranked_chunks)


if __name__ == "__main__":
    asyncio.run(main())