DB-GPT/dbgpt/core/interface/evaluation.py
"""Evaluation module."""
import asyncio
import string
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Generic,
Iterator,
List,
Optional,
Sequence,
Type,
TypeVar,
Union,
)
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.util.similarity_util import calculate_cosine_similarity
from .embeddings import Embeddings
from .llm import LLMClient
if TYPE_CHECKING:
from dbgpt.core.awel.task.base import InputSource
QueryType = Union[str, Any]
PredictionType = Union[str, Any]
ContextType = Union[str, Sequence[str], Any]
DatasetType = Union["InputSource", Iterator, AsyncIterator]
EVALUATE_FILE_COL_QUESTION = "query"
EVALUATE_FILE_COL_ANSWER = "factual"
EVALUATE_FILE_COL_PREDICTION = "prediction"
EVALUATE_FILE_COL_PREDICTION_COST = "prediction_cost"


class BaseEvaluationResult(BaseModel):
    """Base evaluation result."""

    prediction: Optional[PredictionType] = Field(
        None,
        description="Prediction data (including the output of LLM, the data from "
        "retrieval, etc.)",
    )
    contexts: Optional[ContextType] = Field(None, description="Context data")
    score: Optional[float] = Field(
        None, description="Score for the prediction under the current metric"
    )
    passing: Optional[bool] = Field(
        True, description="Whether the current prediction result is valid"
    )
    metric_name: Optional[str] = Field(None, description="Name of the metric")
    prediction_cost: int = 0


class EvaluationResult(BaseEvaluationResult):
    """Evaluation result.

    Output of an Evaluator.
    """

    query: Optional[QueryType] = Field(None, description="Query data")
    raw_dataset: Optional[Any] = Field(None, description="Raw dataset")
    feedback: Optional[str] = Field(None, description="Feedback data")


Q = TypeVar("Q")
P = TypeVar("P")
C = TypeVar("C")


class EvaluationMetric(ABC, Generic[P, C]):
    """Base class for evaluation metrics."""

    def __init__(self, **kwargs):  # noqa
        pass

    @classmethod
    def name(cls) -> str:
        """Name of the metric."""
        return cls.__name__

    @classmethod
    def describe(cls) -> str:
        """Describe the metric."""
        return f"This is an evaluation metric calculation tool named {cls.name()}."

    async def compute(
        self,
        prediction: P,
        contexts: Optional[Sequence[C]] = None,
        query: Optional[str] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric.

        Args:
            prediction(P): The prediction data.
            contexts(Optional[Sequence[C]]): The context data.
            query(Optional[str]): The query text.

        Returns:
            BaseEvaluationResult: The evaluation result.
        """
        return await asyncio.get_running_loop().run_in_executor(
            None, self.sync_compute, prediction, contexts, query
        )

    def sync_compute(
        self,
        prediction: P,
        contexts: Optional[Sequence[C]] = None,
        query: Optional[str] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric synchronously.

        Args:
            prediction(P): The prediction data.
            contexts(Optional[Sequence[C]]): The factual data.
            query(Optional[str]): The query text.

        Returns:
            BaseEvaluationResult: The evaluation result.
        """
        raise NotImplementedError("sync_compute is not implemented")
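

# Illustrative sketch (not part of the library): a custom metric only needs to
# subclass EvaluationMetric and override sync_compute; the async compute method
# then runs it in the default executor automatically. The class below is
# hypothetical.
#
#     class KeywordOverlapMetric(EvaluationMetric[str, str]):
#         """Hypothetical metric: fraction of contexts contained in the prediction."""
#
#         def sync_compute(self, prediction, contexts=None, query=None):
#             if not contexts:
#                 return BaseEvaluationResult(prediction=prediction, score=0.0)
#             hits = sum(1 for c in contexts if c in prediction)
#             return BaseEvaluationResult(
#                 prediction=prediction,
#                 contexts=contexts,
#                 score=hits / len(contexts),
#             )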


class FunctionMetric(EvaluationMetric[P, C], Generic[P, C]):
    """Evaluation metric based on a function."""

    def __init__(self, **kwargs):
        """Create a FunctionMetric.

        Args:
            name(str): The name of the metric.
            func(Callable[[P, Optional[Sequence[C]]], BaseEvaluationResult]):
                The function to use for evaluation.
        """
        if "name" not in kwargs:
            raise ValueError("Parameter 'name' is required")
        if "func" not in kwargs:
            raise ValueError("Parameter 'func' is required")
        self._name = kwargs.get("name", None)
        self.func = kwargs.get("func", None)

    @property
    def name(self) -> str:  # type: ignore # noqa
        """Name of the metric."""
        return self._name

    async def compute(
        self,
        prediction: P,
        context: Optional[Sequence[C]] = None,
        query: Optional[str] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric."""
        return self.func(prediction, context)
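

# Illustrative sketch (not part of the library): wrapping a plain function as a
# metric. The callable receives (prediction, contexts) and must return a
# BaseEvaluationResult; the names below are hypothetical.
#
#     def _contains(prediction, contexts):
#         score = float(bool(contexts) and any(c in prediction for c in contexts))
#         return BaseEvaluationResult(
#             prediction=prediction, contexts=contexts, score=score
#         )
#
#     contains_metric = FunctionMetric(name="contains", func=_contains)
#     # result = await contains_metric.compute("the sky is blue", ["sky"])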


class ExactMatchMetric(EvaluationMetric[str, str]):
    """Exact match metric.

    Only supports string predictions and contexts.
    """

    def __init__(self, **kwargs):
        """Create an ExactMatchMetric."""
        self._ignore_case = kwargs.get("ignore_case", False)
        self._ignore_punctuation = kwargs.get("ignore_punctuation", False)

    async def compute(
        self,
        prediction: str,
        contexts: Optional[Sequence[str]] = None,
        query: Optional[str] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric."""
        if self._ignore_case:
            prediction = prediction.lower()
            if contexts:
                contexts = [c.lower() for c in contexts]
        if self._ignore_punctuation:
            prediction = prediction.translate(
                str.maketrans("", "", string.punctuation)
            )
            if contexts:
                contexts = [
                    c.translate(str.maketrans("", "", string.punctuation))
                    for c in contexts
                ]
        score = 0.0 if not contexts else float(prediction in contexts)
        return BaseEvaluationResult(
            prediction=prediction,
            contexts=contexts,
            score=score,
        )
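

# Illustrative sketch (not part of the library): exact matching against a list
# of reference answers, normalizing case and punctuation first.
#
#     metric = ExactMatchMetric(ignore_case=True, ignore_punctuation=True)
#     # result = await metric.compute("Paris.", contexts=["paris", "Lyon"])
#     # result.score == 1.0 because the normalized prediction equals "paris"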


class SimilarityMetric(EvaluationMetric[str, str]):
    """Similarity metric.

    Calculate the cosine similarity between a prediction and a list of contexts.
    """

    def __init__(self, **kwargs):
        """Create a SimilarityMetric with embeddings."""
        self._embeddings = kwargs.get("embeddings", None)
        if self._embeddings is None or not isinstance(self._embeddings, Embeddings):
            raise ValueError("Need embedding service")

    def sync_compute(
        self,
        prediction: str,
        contexts: Optional[Sequence[str]] = None,
        query: Optional[str] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric."""
        if not contexts:
            return BaseEvaluationResult(
                prediction=prediction,
                contexts=contexts,
                score=0.0,
            )
        try:
            import numpy as np
        except ImportError:
            raise ImportError("numpy is required for SimilarityMetric")

        similarity: np.ndarray = calculate_cosine_similarity(
            self._embeddings, prediction, contexts
        )
        return BaseEvaluationResult(
            prediction=prediction,
            contexts=contexts,
            score=float(similarity.mean()),
        )
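

# Illustrative sketch (not part of the library): scoring semantic closeness
# between a prediction and reference contexts. `my_embeddings` stands for any
# concrete Embeddings implementation and is assumed to exist.
#
#     metric = SimilarityMetric(embeddings=my_embeddings)
#     result = metric.sync_compute(
#         prediction="DB-GPT is a data app framework",
#         contexts=["DB-GPT is an AI native data app development framework"],
#     )
#     # result.score is the mean cosine similarity (higher means more similar)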


class Evaluator(ABC):
    """Base Evaluator class."""

    def __init__(
        self,
        llm_client: Optional[LLMClient] = None,
    ):
        """Create an Evaluator."""
        self.llm_client = llm_client

    @abstractmethod
    async def evaluate(
        self,
        dataset: DatasetType,
        metrics: Optional[List[EvaluationMetric]] = None,
        query_key: str = "query",
        contexts_key: str = "contexts",
        prediction_key: str = "prediction",
        parallel_num: int = 1,
        **kwargs,
    ) -> List[List[EvaluationResult]]:
        """Run evaluation with a dataset and metrics.

        Args:
            dataset(DatasetType): The dataset to evaluate.
            metrics(Optional[List[EvaluationMetric]]): The metrics to use for
                evaluation.
            query_key(str): The key for the query in the dataset.
            contexts_key(str): The key for the contexts in the dataset.
            prediction_key(str): The key for the prediction in the dataset.
            parallel_num(int): The number of parallel tasks.
            kwargs: Additional arguments.

        Returns:
            List[List[EvaluationResult]]: The evaluation results. The length of
                the outer list equals the length of the dataset; each inner list
                holds the evaluation results for the metrics on that row.
        """


class MetricManage:
    """MetricManage."""

    def __init__(self):
        """Init MetricManage."""
        self.metrics = defaultdict()

    def register_metric(self, cls: Type[EvaluationMetric]):
        """Register a metric class by its name."""
        self.metrics[cls.name()] = cls

    def get_by_name(self, name: str) -> Type[EvaluationMetric]:
        """Get a registered metric class by name."""
        if name not in self.metrics:
            raise ValueError(f"Metric:{name} not registered!")
        return self.metrics[name]

    def all_metric_infos(self):
        """Get all metric infos."""
        result = []
        for name, cls in self.metrics.items():
            result.append(
                {
                    "name": name,
                    "describe": cls.describe(),
                }
            )
        return result


metric_mange = MetricManage()
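

# Illustrative sketch (not part of the library): registering metrics with the
# module-level manager and looking them up by name.
#
#     metric_mange.register_metric(ExactMatchMetric)
#     metric_mange.register_metric(SimilarityMetric)
#     exact_cls = metric_mange.get_by_name("ExactMatchMetric")
#     infos = metric_mange.all_metric_infos()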