feat:evaluation add query parameter

This commit is contained in:
aries_ckt
2024-08-08 14:46:49 +08:00
parent 5734f1c116
commit ed916ab9f2

View File

@@ -92,30 +92,34 @@ class EvaluationMetric(ABC, Generic[P, C]):
self, self,
prediction: P, prediction: P,
contexts: Optional[Sequence[C]] = None, contexts: Optional[Sequence[C]] = None,
query: Optional[str] = None,
) -> BaseEvaluationResult: ) -> BaseEvaluationResult:
"""Compute the evaluation metric. """Compute the evaluation metric.
Args: Args:
prediction(P): The prediction data. prediction(P): The prediction data.
contexts(Optional[Sequence[C]]): The context data. contexts(Optional[Sequence[C]]): The context data.
query:(Optional[str]) The query text.
Returns: Returns:
BaseEvaluationResult: The evaluation result. BaseEvaluationResult: The evaluation result.
""" """
return await asyncio.get_running_loop().run_in_executor( return await asyncio.get_running_loop().run_in_executor(
None, self.sync_compute, prediction, contexts None, self.sync_compute, prediction, contexts, query
) )
def sync_compute( def sync_compute(
self, self,
prediction: P, prediction: P,
contexts: Optional[Sequence[C]] = None, contexts: Optional[Sequence[C]] = None,
query: Optional[str] = None,
) -> BaseEvaluationResult: ) -> BaseEvaluationResult:
"""Compute the evaluation metric. """Compute the evaluation metric.
Args: Args:
prediction(P): The prediction data. prediction(P): The prediction data.
contexts(Optional[Sequence[C]]): The factual data. contexts(Optional[Sequence[C]]): The factual data.
query:(Optional[str]) The query text.
Returns: Returns:
BaseEvaluationResult: The evaluation result. BaseEvaluationResult: The evaluation result.
@@ -151,6 +155,7 @@ class FunctionMetric(EvaluationMetric[P, C], Generic[P, C]):
self, self,
prediction: P, prediction: P,
context: Optional[Sequence[C]] = None, context: Optional[Sequence[C]] = None,
query: Optional[str] = None,
) -> BaseEvaluationResult: ) -> BaseEvaluationResult:
"""Compute the evaluation metric.""" """Compute the evaluation metric."""
return self.func(prediction, context) return self.func(prediction, context)
@@ -171,6 +176,7 @@ class ExactMatchMetric(EvaluationMetric[str, str]):
self, self,
prediction: str, prediction: str,
contexts: Optional[Sequence[str]] = None, contexts: Optional[Sequence[str]] = None,
query: Optional[str] = None,
) -> BaseEvaluationResult: ) -> BaseEvaluationResult:
"""Compute the evaluation metric.""" """Compute the evaluation metric."""
if self._ignore_case: if self._ignore_case:
@@ -208,6 +214,7 @@ class SimilarityMetric(EvaluationMetric[str, str]):
self, self,
prediction: str, prediction: str,
contexts: Optional[Sequence[str]] = None, contexts: Optional[Sequence[str]] = None,
query: Optional[str] = None,
) -> BaseEvaluationResult: ) -> BaseEvaluationResult:
"""Compute the evaluation metric.""" """Compute the evaluation metric."""
if not contexts: if not contexts: