feat(RAG): add MRR and HitRate retriever metrics. (#1456)

Aries-ckt 2024-04-25 09:12:45 +08:00 committed by GitHub
parent 6520367623
commit 71529975d8
3 changed files with 96 additions and 6 deletions

Binary image file changed (118 KiB before and after; preview not shown).


@@ -67,6 +67,81 @@ class RetrieverSimilarityMetric(RetrieverEvaluationMetric):
        )

class RetrieverMRRMetric(RetrieverEvaluationMetric):
    """Retriever Mean Reciprocal Rank metric.

    For each query, MRR evaluates the system's accuracy by looking at the rank of the
    highest-placed relevant document. Specifically, it's the average of the reciprocals
    of these ranks across all the queries. So, if the first relevant document is the
    top result, the reciprocal rank is 1; if it's second, the reciprocal rank is 1/2,
    and so on.
    """

    def sync_compute(
        self,
        prediction: Optional[List[str]] = None,
        contexts: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> BaseEvaluationResult:
        """Compute MRR metric.

        Args:
            prediction(Optional[List[str]]): The retrieved chunks from the retriever.
            contexts(Optional[List[str]]): The contexts from the dataset.

        Returns:
            BaseEvaluationResult: The evaluation result.
                The score is the reciprocal rank of the first relevant chunk.
        """
        if not prediction or not contexts:
            return BaseEvaluationResult(
                prediction=prediction,
                contexts=contexts,
                score=0.0,
            )
        for i, retrieved_chunk in enumerate(prediction):
            if retrieved_chunk in contexts:
                return BaseEvaluationResult(
                    score=1.0 / (i + 1),
                )
        return BaseEvaluationResult(
            score=0.0,
        )

class RetrieverHitRateMetric(RetrieverEvaluationMetric):
    """Retriever Hit Rate metric.

    Hit rate calculates the fraction of queries where the correct answer is found
    within the top-k retrieved documents. In simpler terms, it's about how often our
    system gets it right within the top few guesses.
    """

    def sync_compute(
        self,
        prediction: Optional[List[str]] = None,
        contexts: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> BaseEvaluationResult:
        """Compute HitRate metric.

        Args:
            prediction(Optional[List[str]]): The retrieved chunks from the retriever.
            contexts(Optional[List[str]]): The contexts from the dataset.

        Returns:
            BaseEvaluationResult: The evaluation result.
        """
        if not prediction or not contexts:
            return BaseEvaluationResult(
                prediction=prediction,
                contexts=contexts,
                score=0.0,
            )
        is_hit = any(context in prediction for context in contexts)
        return BaseEvaluationResult(
            score=1.0 if is_hit else 0.0,
        )

class RetrieverEvaluator(Evaluator):
    """Evaluator for relevancy.
@@ -144,7 +219,7 @@ class RetrieverEvaluator(Evaluator):
        contexts_key: str = "contexts",
        prediction_key: str = "prediction",
        parallel_num: int = 1,
-        **kwargs
+        **kwargs,
    ) -> List[List[EvaluationResult]]:
        """Evaluate the dataset."""
        from dbgpt.core.awel import DAG, IteratorTrigger, MapOperator
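
To make the scoring behavior of the two new metrics concrete, here is a minimal sketch (not part of this commit) that calls their sync_compute methods directly. Only the import path, constructor usage, and signatures shown in this diff are relied on; the chunk strings are invented for illustration.

from dbgpt.rag.evaluation.retriever import RetrieverHitRateMetric, RetrieverMRRMetric

# Ground-truth chunks from a dataset entry and a ranked retrieval (both invented).
contexts = ["AWEL adopts a layered API design."]
prediction = [
    "DB-GPT supports multiple model providers.",
    "AWEL adopts a layered API design.",
]

# The first relevant chunk sits at rank 2, so its reciprocal rank is 1 / 2.
mrr_result = RetrieverMRRMetric().sync_compute(prediction=prediction, contexts=contexts)
print(mrr_result.score)  # 0.5

# At least one ground-truth chunk was retrieved, so this query counts as a hit.
hit_result = RetrieverHitRateMetric().sync_compute(prediction=prediction, contexts=contexts)
print(hit_result.score)  # 1.0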


@@ -8,6 +8,11 @@ from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.evaluation import RetrieverEvaluator
from dbgpt.rag.evaluation.retriever import (
    RetrieverHitRateMetric,
    RetrieverMRRMetric,
    RetrieverSimilarityMetric,
)
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.rag.operators import EmbeddingRetrieverOperator
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
@@ -42,7 +47,7 @@ async def main():
    knowledge = KnowledgeFactory.from_file_path(file_path)
    embeddings = _create_embeddings()
    vector_connector = _create_vector_connector(embeddings)
-    chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
+    chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_MARKDOWN_HEADER")
    # get embedding assembler
    assembler = EmbeddingAssembler.load_from_knowledge(
        knowledge=knowledge,
@@ -55,9 +60,14 @@ async def main():
        {
            "query": "what is awel talk about",
            "contexts": [
-                "Through the AWEL API, you can focus on the development"
-                " of business logic for LLMs applications without paying "
-                "attention to cumbersome model and environment details."
+                "# What is AWEL? \n\nAgentic Workflow Expression Language(AWEL) is a "
+                "set of intelligent agent workflow expression language specially "
+                "designed for large model application\ndevelopment. It provides great "
+                "functionality and flexibility. Through the AWEL API, you can focus on "
+                "the development of business logic for LLMs applications\nwithout "
+                "paying attention to cumbersome model and environment details.\n\nAWEL "
+                "adopts a layered API design. AWEL's layered API design architecture is "
+                "shown in the figure below."
            ],
        },
    ]
@@ -69,7 +79,12 @@ async def main():
            "vector_store_connector": vector_connector,
        },
    )
-    results = await evaluator.evaluate(dataset)
+    metrics = [
+        RetrieverHitRateMetric(),
+        RetrieverMRRMetric(),
+        RetrieverSimilarityMetric(embeddings=embeddings),
+    ]
+    results = await evaluator.evaluate(dataset, metrics)
    for result in results:
        for metric in result:
            print("Metric:", metric.metric_name)