Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-08-28 12:51:54 +00:00
feat(RAG):add MRR and HitRate retriever metrics. (#1456)
This commit is contained in:
parent 6520367623
commit 71529975d8
(A binary image file also changed in this commit, not shown: 118 KiB before and after.)
@@ -67,6 +67,81 @@ class RetrieverSimilarityMetric(RetrieverEvaluationMetric):
         )
 
 
+class RetrieverMRRMetric(RetrieverEvaluationMetric):
+    """Retriever Mean Reciprocal Rank metric.
+
+    For each query, MRR evaluates the system's accuracy by looking at the rank of the
+    highest-placed relevant document. Specifically, it's the average of the reciprocals
+    of these ranks across all the queries. So, if the first relevant document is the
+    top result, the reciprocal rank is 1; if it's second, the reciprocal rank is 1/2,
+    and so on.
+    """
+
+    def sync_compute(
+        self,
+        prediction: Optional[List[str]] = None,
+        contexts: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> BaseEvaluationResult:
+        """Compute MRR metric.
+
+        Args:
+            prediction(Optional[List[str]]): The retrieved chunks from the retriever.
+            contexts(Optional[List[str]]): The contexts from dataset.
+        Returns:
+            BaseEvaluationResult: The evaluation result.
+                The score is the reciprocal rank of the first relevant chunk.
+        """
+        if not prediction or not contexts:
+            return BaseEvaluationResult(
+                prediction=prediction,
+                contexts=contexts,
+                score=0.0,
+            )
+        for i, retrieved_chunk in enumerate(prediction):
+            if retrieved_chunk in contexts:
+                return BaseEvaluationResult(
+                    score=1.0 / (i + 1),
+                )
+        return BaseEvaluationResult(
+            score=0.0,
+        )
+
+
+class RetrieverHitRateMetric(RetrieverEvaluationMetric):
+    """Retriever Hit Rate metric.
+
+    Hit rate calculates the fraction of queries where the correct answer is found
+    within the top-k retrieved documents. In simpler terms, it's about how often our
+    system gets it right within the top few guesses.
+    """
+
+    def sync_compute(
+        self,
+        prediction: Optional[List[str]] = None,
+        contexts: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> BaseEvaluationResult:
+        """Compute HitRate metric.
+
+        Args:
+            prediction(Optional[List[str]]): The retrieved chunks from the retriever.
+            contexts(Optional[List[str]]): The contexts from dataset.
+        Returns:
+            BaseEvaluationResult: The evaluation result.
+        """
+        if not prediction or not contexts:
+            return BaseEvaluationResult(
+                prediction=prediction,
+                contexts=contexts,
+                score=0.0,
+            )
+        is_hit = any(context in prediction for context in contexts)
+        return BaseEvaluationResult(
+            score=1.0 if is_hit else 0.0,
+        )
+
+
 class RetrieverEvaluator(Evaluator):
     """Evaluator for relevancy.
 
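The reciprocal-rank rule in RetrieverMRRMetric above is easy to verify by hand. Below is a minimal, self-contained sketch with toy chunks (hypothetical data, not part of the commit) that mirrors the loop in sync_compute:

    # Toy data: three retrieved chunks, one of which matches a ground-truth context.
    prediction = ["chunk A", "chunk B", "chunk C"]
    contexts = ["chunk B"]

    # Same rule as RetrieverMRRMetric.sync_compute: reciprocal rank of the first relevant chunk.
    score = 0.0
    for i, retrieved_chunk in enumerate(prediction):
        if retrieved_chunk in contexts:
            score = 1.0 / (i + 1)
            break

    print(score)  # 0.5, because the first relevant chunk sits at rank 2

Averaged over a set of queries, these per-query reciprocal ranks give the Mean Reciprocal Rank.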
@@ -144,7 +219,7 @@ class RetrieverEvaluator(Evaluator):
         contexts_key: str = "contexts",
         prediction_key: str = "prediction",
         parallel_num: int = 1,
-        **kwargs
+        **kwargs,
     ) -> List[List[EvaluationResult]]:
         """Evaluate the dataset."""
         from dbgpt.core.awel import DAG, IteratorTrigger, MapOperator
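The hit-rate check in RetrieverHitRateMetric is the binary counterpart: a query scores 1.0 if any ground-truth context shows up among the retrieved chunks, and 0.0 otherwise; the hit rate is the fraction of queries that score 1.0. A small self-contained sketch with the same kind of toy data (hypothetical, not part of the commit):

    # Toy data: one of the two ground-truth contexts was retrieved, so this query is a hit.
    prediction = ["chunk A", "chunk B", "chunk C"]
    contexts = ["chunk B", "chunk D"]

    is_hit = any(context in prediction for context in contexts)
    score = 1.0 if is_hit else 0.0
    print(score)  # 1.0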
@@ -8,6 +8,11 @@ from dbgpt.rag import ChunkParameters
 from dbgpt.rag.assembler import EmbeddingAssembler
 from dbgpt.rag.embedding import DefaultEmbeddingFactory
 from dbgpt.rag.evaluation import RetrieverEvaluator
+from dbgpt.rag.evaluation.retriever import (
+    RetrieverHitRateMetric,
+    RetrieverMRRMetric,
+    RetrieverSimilarityMetric,
+)
 from dbgpt.rag.knowledge import KnowledgeFactory
 from dbgpt.rag.operators import EmbeddingRetrieverOperator
 from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
@@ -42,7 +47,7 @@ async def main():
     knowledge = KnowledgeFactory.from_file_path(file_path)
     embeddings = _create_embeddings()
     vector_connector = _create_vector_connector(embeddings)
-    chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
+    chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_MARKDOWN_HEADER")
     # get embedding assembler
     assembler = EmbeddingAssembler.load_from_knowledge(
         knowledge=knowledge,
@@ -55,9 +60,14 @@ async def main():
         {
             "query": "what is awel talk about",
             "contexts": [
-                "Through the AWEL API, you can focus on the development"
-                " of business logic for LLMs applications without paying "
-                "attention to cumbersome model and environment details."
+                "# What is AWEL? \n\nAgentic Workflow Expression Language(AWEL) is a "
+                "set of intelligent agent workflow expression language specially "
+                "designed for large model application\ndevelopment. It provides great "
+                "functionality and flexibility. Through the AWEL API, you can focus on "
+                "the development of business logic for LLMs applications\nwithout "
+                "paying attention to cumbersome model and environment details.\n\nAWEL "
+                "adopts a layered API design. AWEL's layered API design architecture is "
+                "shown in the figure below."
             ],
         },
     ]
@@ -69,7 +79,12 @@ async def main():
             "vector_store_connector": vector_connector,
         },
     )
-    results = await evaluator.evaluate(dataset)
+    metrics = [
+        RetrieverHitRateMetric(),
+        RetrieverMRRMetric(),
+        RetrieverSimilarityMetric(embeddings=embeddings),
+    ]
+    results = await evaluator.evaluate(dataset, metrics)
     for result in results:
         for metric in result:
             print("Metric:", metric.metric_name)
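As a quick sanity check outside the evaluator pipeline, the two new metrics can also be called synchronously. This is a hedged sketch rather than part of the commit: it assumes DB-GPT is installed and relies only on the sync_compute signature and the BaseEvaluationResult score field shown in the diff above, with hypothetical toy data.

    from dbgpt.rag.evaluation.retriever import RetrieverHitRateMetric, RetrieverMRRMetric

    # Hypothetical retrieval output and ground truth, just to illustrate the scores.
    prediction = ["some other chunk", "the relevant chunk"]
    contexts = ["the relevant chunk"]

    mrr = RetrieverMRRMetric().sync_compute(prediction=prediction, contexts=contexts)
    hit = RetrieverHitRateMetric().sync_compute(prediction=prediction, contexts=contexts)
    print(mrr.score)  # expected 0.5: the first relevant chunk is at rank 2
    print(hit.score)  # expected 1.0: a ground-truth context was retrieved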