mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-04 10:34:30 +00:00
refactor: RAG Refactor (#985)
Co-authored-by: Aralhi <xiaoping0501@gmail.com> Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
from abc import ABC
|
||||
from typing import List, Tuple, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
|
||||
|
||||
class Ranker(ABC):
|
||||
"""Base Ranker"""
|
||||
|
||||
def __init__(self, topk: int, rank_fn: Optional[callable] = None):
|
||||
def __init__(self, topk: int, rank_fn: Optional[callable] = None) -> None:
|
||||
"""
|
||||
abstract base ranker
|
||||
Args:
|
||||
@@ -15,7 +17,7 @@ class Ranker(ABC):
|
||||
self.topk = topk
|
||||
self.rank_fn = rank_fn
|
||||
|
||||
def rank(self, candidates_with_scores: List, topk: int):
|
||||
def rank(self, candidates_with_scores: List) -> List[Chunk]:
|
||||
"""rank algorithm implementation return topk documents by candidates similarity score
|
||||
Args:
|
||||
candidates_with_scores: List[Tuple]
|
||||
@@ -26,17 +28,17 @@ class Ranker(ABC):
|
||||
|
||||
pass
|
||||
|
||||
def _filter(self, candidates_with_scores: List):
|
||||
def _filter(self, candidates_with_scores: List) -> List[Chunk]:
|
||||
"""filter duplicate candidates documents"""
|
||||
candidates_with_scores = sorted(
|
||||
candidates_with_scores, key=lambda x: x[1], reverse=True
|
||||
candidates_with_scores, key=lambda x: x.score, reverse=True
|
||||
)
|
||||
visited_docs = set()
|
||||
new_candidates = []
|
||||
for candidate_doc, score in candidates_with_scores:
|
||||
if candidate_doc.page_content not in visited_docs:
|
||||
new_candidates.append((candidate_doc, score))
|
||||
visited_docs.add(candidate_doc.page_content)
|
||||
for candidate_chunk in candidates_with_scores:
|
||||
if candidate_chunk.content not in visited_docs:
|
||||
new_candidates.append(candidate_chunk)
|
||||
visited_docs.add(candidate_chunk.content)
|
||||
return new_candidates
|
||||
|
||||
|
||||
@@ -46,7 +48,7 @@ class DefaultRanker(Ranker):
|
||||
def __init__(self, topk: int, rank_fn: Optional[callable] = None):
|
||||
super().__init__(topk, rank_fn)
|
||||
|
||||
def rank(self, candidates_with_scores: List[Tuple]):
|
||||
def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]:
|
||||
"""Default rank algorithm implementation
|
||||
return topk documents by candidates similarity score
|
||||
Args:
|
||||
@@ -59,11 +61,9 @@ class DefaultRanker(Ranker):
|
||||
candidates_with_scores = self.rank_fn(candidates_with_scores)
|
||||
else:
|
||||
candidates_with_scores = sorted(
|
||||
candidates_with_scores, key=lambda x: x[1], reverse=True
|
||||
candidates_with_scores, key=lambda x: x.score, reverse=True
|
||||
)
|
||||
return [
|
||||
(candidate_doc, score) for candidate_doc, score in candidates_with_scores
|
||||
][: self.topk]
|
||||
return candidates_with_scores[: self.topk]
|
||||
|
||||
|
||||
class RRFRanker(Ranker):
|
||||
@@ -72,7 +72,7 @@ class RRFRanker(Ranker):
|
||||
def __init__(self, topk: int, rank_fn: Optional[callable] = None):
|
||||
super().__init__(topk, rank_fn)
|
||||
|
||||
def rank(self, candidates_with_scores: List):
|
||||
def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]:
|
||||
"""RRF rank algorithm implementation
|
||||
This code implements an algorithm called Reciprocal Rank Fusion (RRF), is a method for combining multiple result sets with different relevance indicators into a single result set. RRF requires no tuning, and the different relevance indicators do not have to be related to each other to achieve high-quality results.
|
||||
RRF uses the following formula to determine the score for ranking each document:
|
||||
|
Reference in New Issue
Block a user