feat(rag): Support rag retriever evaluation (#1291)

This commit is contained in:
Fangyin Cheng
2024-03-14 13:06:57 +08:00
committed by GitHub
parent cd2dcc253c
commit adaa68eb00
34 changed files with 1452 additions and 67 deletions

View File

@@ -3,6 +3,7 @@
from .datasource import DatasourceRetrieverOperator # noqa: F401
from .db_schema import DBSchemaRetrieverOperator # noqa: F401
from .embedding import EmbeddingRetrieverOperator # noqa: F401
from .evaluation import RetrieverEvaluatorOperator # noqa: F401
from .knowledge import KnowledgeOperator # noqa: F401
from .rerank import RerankOperator # noqa: F401
from .rewrite import QueryRewriteOperator # noqa: F401
@@ -16,4 +17,5 @@ __all__ = [
"RerankOperator",
"QueryRewriteOperator",
"SummaryAssemblerOperator",
"RetrieverEvaluatorOperator",
]

View File

@@ -1,16 +1,17 @@
"""Embedding retriever operator."""
from functools import reduce
from typing import Any, Optional
from typing import List, Optional, Union
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.rag.retriever.rerank import Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.connector import VectorStoreConnector
class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[Chunk]]):
"""The Embedding Retriever Operator."""
def __init__(
@@ -32,7 +33,7 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
rerank=rerank,
)
def retrieve(self, query: Any) -> Any:
def retrieve(self, query: Union[str, List[str]]) -> List[Chunk]:
"""Retrieve the candidates."""
if isinstance(query, str):
return self._retriever.retrieve_with_scores(query, self._score_threshold)

View File

@@ -0,0 +1,61 @@
"""Evaluation operators."""
import asyncio
from typing import Any, List, Optional
from dbgpt.core.awel import JoinOperator
from dbgpt.core.interface.evaluation import EvaluationMetric, EvaluationResult
from dbgpt.core.interface.llm import LLMClient
from ..chunk import Chunk
class RetrieverEvaluatorOperator(JoinOperator[List[EvaluationResult]]):
    """Join operator that scores retrieved chunks with a set of metrics."""

    def __init__(
        self,
        evaluation_metrics: List[EvaluationMetric],
        llm_client: Optional[LLMClient] = None,
        **kwargs,
    ):
        """Initialize the operator.

        Args:
            evaluation_metrics(List[EvaluationMetric]): The metrics computed for
                every evaluated sample.
            llm_client(Optional[LLMClient]): Kept on the instance; the code in
                this class does not use it directly.
            **kwargs: Forwarded unchanged to the parent ``JoinOperator``.
        """
        self.llm_client = llm_client
        self.evaluation_metrics = evaluation_metrics
        # The bound evaluation coroutine is what the join operator uses to
        # combine its upstream inputs.
        super().__init__(combine_function=self._do_evaluation, **kwargs)

    async def _do_evaluation(
        self,
        query: str,
        prediction: List[Chunk],
        contexts: List[str],
        raw_dataset: Any = None,
    ) -> List[EvaluationResult]:
        """Compute every configured metric for one query.

        Args:
            query(str): The query string.
            prediction(List[Chunk]): The chunks returned by the retriever.
            contexts(List[str]): The reference contexts from the dataset. A
                single string is also accepted and is wrapped into a list.
            raw_dataset(Any): The raw data(single row) from dataset.

        Returns:
            List[EvaluationResult]: One result per metric, in the same order as
                ``self.evaluation_metrics``.
        """
        expected = [contexts] if isinstance(contexts, str) else contexts
        retrieved_texts = [chunk.content for chunk in prediction]

        # Create all metric computations first, then await them together.
        pending = [
            metric.compute(retrieved_texts, expected)
            for metric in self.evaluation_metrics
        ]
        outcomes = await asyncio.gather(*pending)

        # gather() keeps input order, so outcomes line up with the metrics.
        return [
            EvaluationResult(
                query=query,
                prediction=prediction,
                score=outcome.score,
                contexts=expected,
                passing=outcome.passing,
                raw_dataset=raw_dataset,
                metric_name=metric.name,
            )
            for outcome, metric in zip(outcomes, self.evaluation_metrics)
        ]