feat(RAG): add MRR and HitRate retriever metrics. (#1456)

Aries-ckt
2024-04-25 09:12:45 +08:00
committed by GitHub
parent 6520367623
commit 71529975d8
3 changed files with 96 additions and 6 deletions

@@ -67,6 +67,81 @@ class RetrieverSimilarityMetric(RetrieverEvaluationMetric):
    )


class RetrieverMRRMetric(RetrieverEvaluationMetric):
    """Retriever Mean Reciprocal Rank metric.

    For each query, MRR evaluates the system's accuracy by looking at the rank of
    the highest-placed relevant document. Specifically, it is the average of the
    reciprocals of these ranks across all the queries. So, if the first relevant
    document is the top result, the reciprocal rank is 1; if it is second, the
    reciprocal rank is 1/2, and so on.
    """

    def sync_compute(
        self,
        prediction: Optional[List[str]] = None,
        contexts: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> BaseEvaluationResult:
        """Compute MRR metric.

        Args:
            prediction(Optional[List[str]]): The retrieved chunks from the retriever.
            contexts(Optional[List[str]]): The contexts from the dataset.

        Returns:
            BaseEvaluationResult: The evaluation result.
                The score is the reciprocal rank of the first relevant chunk.
        """
        if not prediction or not contexts:
            return BaseEvaluationResult(
                prediction=prediction,
                contexts=contexts,
                score=0.0,
            )
        # Walk the ranked predictions; the first chunk that matches a ground-truth
        # context determines the reciprocal rank.
        for i, retrieved_chunk in enumerate(prediction):
            if retrieved_chunk in contexts:
                return BaseEvaluationResult(
                    score=1.0 / (i + 1),
                )
        # No retrieved chunk matched any ground-truth context.
        return BaseEvaluationResult(
            score=0.0,
        )
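# --- Illustrative usage sketch, not part of this commit: how the MRR score
# behaves for a single query. Assumes RetrieverMRRMetric (defined above) takes
# no constructor arguments. ---
mrr_metric = RetrieverMRRMetric()
# The first relevant chunk sits at rank 2, so the reciprocal rank is 1 / 2.
mrr_result = mrr_metric.sync_compute(
    prediction=["chunk-a", "chunk-b", "chunk-c"],  # ranked retriever output
    contexts=["chunk-b"],                          # ground-truth contexts
)
assert mrr_result.score == 0.5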
class RetrieverHitRateMetric(RetrieverEvaluationMetric):
    """Retriever Hit Rate metric.

    Hit rate calculates the fraction of queries where the correct answer is found
    within the top-k retrieved documents. In simpler terms, it is about how often
    our system gets it right within the top few guesses.
    """

    def sync_compute(
        self,
        prediction: Optional[List[str]] = None,
        contexts: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> BaseEvaluationResult:
        """Compute HitRate metric.

        Args:
            prediction(Optional[List[str]]): The retrieved chunks from the retriever.
            contexts(Optional[List[str]]): The contexts from the dataset.

        Returns:
            BaseEvaluationResult: The evaluation result.
        """
        if not prediction or not contexts:
            return BaseEvaluationResult(
                prediction=prediction,
                contexts=contexts,
                score=0.0,
            )
        # A query counts as a hit if any ground-truth context appears among the
        # retrieved chunks.
        is_hit = any(context in prediction for context in contexts)
        return BaseEvaluationResult(
            score=1.0 if is_hit else 0.0,
        )
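# --- Illustrative usage sketch, not part of this commit: hit rate for a single
# query is 1.0 if any ground-truth context appears among the retrieved chunks,
# else 0.0. Assumes RetrieverHitRateMetric (defined above) takes no constructor
# arguments. ---
hit_metric = RetrieverHitRateMetric()
hit = hit_metric.sync_compute(
    prediction=["chunk-a", "chunk-b"],  # top-k retrieved chunks
    contexts=["chunk-b", "chunk-z"],    # "chunk-b" was retrieved -> hit
)
miss = hit_metric.sync_compute(
    prediction=["chunk-a"],
    contexts=["chunk-z"],               # nothing relevant retrieved -> miss
)
assert hit.score == 1.0 and miss.score == 0.0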
class RetrieverEvaluator(Evaluator):
    """Evaluator for relevancy.
@@ -144,7 +219,7 @@ class RetrieverEvaluator(Evaluator):
        contexts_key: str = "contexts",
        prediction_key: str = "prediction",
        parallel_num: int = 1,
-        **kwargs
+        **kwargs,
    ) -> List[List[EvaluationResult]]:
        """Evaluate the dataset."""
        from dbgpt.core.awel import DAG, IteratorTrigger, MapOperator
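# --- Illustrative sketch, not part of this commit; the exact evaluate() call
# shape is not visible in this diff, so only the assumed row layout is shown.
# A dataset row is taken to be a dict whose keys line up with the
# contexts_key / prediction_key parameters above. ---
example_row = {
    "query": "What does MRR measure?",  # assumed query field
    "contexts": ["chunk-b"],            # read via contexts_key="contexts"
    # The retriever's output is attached under prediction_key="prediction"
    # before the metrics above are applied.
}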