mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-02 01:27:14 +00:00
feat(model): Support deploy rerank model (#1522)
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
"""Rerank module for RAG retriever."""
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core import Chunk, RerankEmbeddings
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
@@ -39,6 +40,24 @@ class Ranker(ABC):
|
||||
List[Chunk]
|
||||
"""
|
||||
|
||||
async def arank(
    self, candidates_with_scores: List[Chunk], query: Optional[str] = None
) -> List[Chunk]:
    """Asynchronously return the top k chunks after ranking.

    Delegates to the synchronous ``rank`` method, executed through the
    running event loop's default executor so the loop itself is not blocked.

    Args:
        candidates_with_scores: List[Chunk], candidate chunks to rank
        query: Optional[str], query text passed through to ``rank``
    Return:
        List[Chunk], whatever ``rank`` returns for the same arguments
    """
    loop = asyncio.get_running_loop()
    # ``None`` selects the loop's default executor.
    ranked_chunks = await loop.run_in_executor(
        None, self.rank, candidates_with_scores, query
    )
    return ranked_chunks
|
||||
|
||||
def _filter(self, candidates_with_scores: List) -> List[Chunk]:
|
||||
"""Filter duplicate candidates documents."""
|
||||
candidates_with_scores = sorted(
|
||||
@@ -52,6 +71,18 @@ class Ranker(ABC):
|
||||
visited_docs.add(candidate_chunk.content)
|
||||
return new_candidates
|
||||
|
||||
def _rerank_with_scores(
|
||||
self, candidates_with_scores: List[Chunk], rank_scores: List[float]
|
||||
) -> List[Chunk]:
|
||||
"""Rerank candidates with scores."""
|
||||
for candidate, score in zip(candidates_with_scores, rank_scores):
|
||||
candidate.score = float(score)
|
||||
|
||||
new_candidates_with_scores = sorted(
|
||||
candidates_with_scores, key=lambda x: x.score, reverse=True
|
||||
)
|
||||
return new_candidates_with_scores
|
||||
|
||||
|
||||
@register_resource(
|
||||
_("Default Ranker"),
|
||||
@@ -225,3 +256,59 @@ class CrossEncoderRanker(Ranker):
|
||||
candidates_with_scores, key=lambda x: x.score, reverse=True
|
||||
)
|
||||
return new_candidates_with_scores[: self.topk]
|
||||
|
||||
|
||||
class RerankEmbeddingsRanker(Ranker):
    """Ranker that scores candidates with a ``RerankEmbeddings`` model."""

    def __init__(
        self,
        rerank_embeddings: RerankEmbeddings,
        topk: int = 4,
        rank_fn: Optional[RANK_FUNC] = None,
    ):
        """Create a ranker backed by a rerank embeddings model.

        Args:
            rerank_embeddings: RerankEmbeddings, model used to score the
                candidate contents against the query
            topk: int, forwarded to the base ``Ranker``
            rank_fn: Optional[RANK_FUNC], forwarded to the base ``Ranker``
        """
        self._model = rerank_embeddings
        super().__init__(topk, rank_fn)

    def rank(
        self, candidates_with_scores: List[Chunk], query: Optional[str] = None
    ) -> List[Chunk]:
        """Rerank the candidates against the query using the model.

        Args:
            candidates_with_scores: List[Chunk], candidates with scores
            query: Optional[str], query text
        Returns:
            List[Chunk], the first ``self.topk`` reranked candidates; the
            input is returned unchanged when candidates or query are empty
        """
        if candidates_with_scores and query:
            texts = [chunk.content for chunk in candidates_with_scores]
            scores = self._model.predict(query, texts)
            reranked = self._rerank_with_scores(candidates_with_scores, scores)
            return reranked[: self.topk]
        # Nothing to score, or nothing to score against.
        return candidates_with_scores

    async def arank(
        self, candidates_with_scores: List[Chunk], query: Optional[str] = None
    ) -> List[Chunk]:
        """Asynchronously rerank the candidates against the query.

        Same contract as ``rank``, but awaits the model's ``apredict``.

        Args:
            candidates_with_scores: List[Chunk], candidates with scores
            query: Optional[str], query text
        Returns:
            List[Chunk], the first ``self.topk`` reranked candidates; the
            input is returned unchanged when candidates or query are empty
        """
        if candidates_with_scores and query:
            texts = [chunk.content for chunk in candidates_with_scores]
            scores = await self._model.apredict(query, texts)
            reranked = self._rerank_with_scores(candidates_with_scores, scores)
            return reranked[: self.topk]
        # Nothing to score, or nothing to score against.
        return candidates_with_scores
|
||||
|
Reference in New Issue
Block a user