mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-13 21:21:08 +00:00
feat(rag): Support rag retriever evaluation (#1291)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
from .datasource import DatasourceRetrieverOperator # noqa: F401
|
||||
from .db_schema import DBSchemaRetrieverOperator # noqa: F401
|
||||
from .embedding import EmbeddingRetrieverOperator # noqa: F401
|
||||
from .evaluation import RetrieverEvaluatorOperator # noqa: F401
|
||||
from .knowledge import KnowledgeOperator # noqa: F401
|
||||
from .rerank import RerankOperator # noqa: F401
|
||||
from .rewrite import QueryRewriteOperator # noqa: F401
|
||||
@@ -16,4 +17,5 @@ __all__ = [
|
||||
"RerankOperator",
|
||||
"QueryRewriteOperator",
|
||||
"SummaryAssemblerOperator",
|
||||
"RetrieverEvaluatorOperator",
|
||||
]
|
||||
|
@@ -1,16 +1,17 @@
|
||||
"""Embedding retriever operator."""
|
||||
|
||||
from functools import reduce
|
||||
from typing import Any, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
|
||||
from dbgpt.rag.retriever.rerank import Ranker
|
||||
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
||||
class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||
class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[Chunk]]):
|
||||
"""The Embedding Retriever Operator."""
|
||||
|
||||
def __init__(
|
||||
@@ -32,7 +33,7 @@ class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
|
||||
rerank=rerank,
|
||||
)
|
||||
|
||||
def retrieve(self, query: Any) -> Any:
|
||||
def retrieve(self, query: Union[str, List[str]]) -> List[Chunk]:
|
||||
"""Retrieve the candidates."""
|
||||
if isinstance(query, str):
|
||||
return self._retriever.retrieve_with_scores(query, self._score_threshold)
|
||||
|
61
dbgpt/rag/operators/evaluation.py
Normal file
61
dbgpt/rag/operators/evaluation.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Evaluation operators."""
|
||||
import asyncio
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.core.awel import JoinOperator
|
||||
from dbgpt.core.interface.evaluation import EvaluationMetric, EvaluationResult
|
||||
from dbgpt.core.interface.llm import LLMClient
|
||||
|
||||
from ..chunk import Chunk
|
||||
|
||||
|
||||
class RetrieverEvaluatorOperator(JoinOperator[List[EvaluationResult]]):
|
||||
"""Evaluator for retriever."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
evaluation_metrics: List[EvaluationMetric],
|
||||
llm_client: Optional[LLMClient] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a new RetrieverEvaluatorOperator."""
|
||||
self.llm_client = llm_client
|
||||
self.evaluation_metrics = evaluation_metrics
|
||||
super().__init__(combine_function=self._do_evaluation, **kwargs)
|
||||
|
||||
async def _do_evaluation(
|
||||
self,
|
||||
query: str,
|
||||
prediction: List[Chunk],
|
||||
contexts: List[str],
|
||||
raw_dataset: Any = None,
|
||||
) -> List[EvaluationResult]:
|
||||
"""Run evaluation.
|
||||
|
||||
Args:
|
||||
query(str): The query string.
|
||||
prediction(List[Chunk]): The retrieved chunks from the retriever.
|
||||
contexts(List[str]): The contexts from dataset.
|
||||
raw_dataset(Any): The raw data(single row) from dataset.
|
||||
"""
|
||||
if isinstance(contexts, str):
|
||||
contexts = [contexts]
|
||||
prediction_strs = [chunk.content for chunk in prediction]
|
||||
tasks = []
|
||||
for metric in self.evaluation_metrics:
|
||||
tasks.append(metric.compute(prediction_strs, contexts))
|
||||
task_results = await asyncio.gather(*tasks)
|
||||
results = []
|
||||
for result, metric in zip(task_results, self.evaluation_metrics):
|
||||
results.append(
|
||||
EvaluationResult(
|
||||
query=query,
|
||||
prediction=prediction,
|
||||
score=result.score,
|
||||
contexts=contexts,
|
||||
passing=result.passing,
|
||||
raw_dataset=raw_dataset,
|
||||
metric_name=metric.name,
|
||||
)
|
||||
)
|
||||
return results
|
Reference in New Issue
Block a user