mirror of
				https://github.com/csunny/DB-GPT.git
				synced 2025-10-25 11:35:41 +00:00 
			
		
		
		
	Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com> Co-authored-by: csunny <cfqsunny@163.com>
		
			
				
	
	
		
			312 lines
		
	
	
		
			9.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			312 lines
		
	
	
		
			9.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """Evaluation module."""
 | ||
| import asyncio
 | ||
| import string
 | ||
| from abc import ABC, abstractmethod
 | ||
| from collections import defaultdict
 | ||
| from typing import (
 | ||
|     TYPE_CHECKING,
 | ||
|     Any,
 | ||
|     AsyncIterator,
 | ||
|     Generic,
 | ||
|     Iterator,
 | ||
|     List,
 | ||
|     Optional,
 | ||
|     Sequence,
 | ||
|     Type,
 | ||
|     TypeVar,
 | ||
|     Union,
 | ||
| )
 | ||
| 
 | ||
| from dbgpt._private.pydantic import BaseModel, Field
 | ||
| from dbgpt.util.similarity_util import calculate_cosine_similarity
 | ||
| 
 | ||
| from .embeddings import Embeddings
 | ||
| from .llm import LLMClient
 | ||
| 
 | ||
| if TYPE_CHECKING:
 | ||
|     from dbgpt.core.awel.task.base import InputSource
 | ||
| 
 | ||
# Type aliases describing the values handled by evaluation code.
QueryType = Union[str, Any]
PredictionType = Union[str, Any]
ContextType = Union[str, Sequence[str], Any]
# ``InputSource`` is only imported under TYPE_CHECKING, so it is referenced by
# name (a forward reference) here.
DatasetType = Union["InputSource", Iterator, AsyncIterator]

# Column-name constants (presumably for evaluation data files; the readers and
# writers are not in this module).
EVALUATE_FILE_COL_QUESTION = "query"
EVALUATE_FILE_COL_ANSWER = "factual"
EVALUATE_FILE_COL_PREDICTION = "prediction"
EVALUATE_FILE_COL_PREDICTION_COST = "prediction_cost"
 | ||
| 
 | ||
| 
 | ||
class BaseEvaluationResult(BaseModel):
    """Base evaluation result.

    Holds the fields every metric produces: the prediction that was scored, the
    contexts it was scored against, the score itself and bookkeeping data.
    """

    # The data being evaluated.
    prediction: Optional[PredictionType] = Field(
        default=None,
        description=(
            "Prediction data(including the output of LLM, "
            "the data from retrieval, etc.)"
        ),
    )
    # Context data associated with the prediction.
    contexts: Optional[ContextType] = Field(
        default=None,
        description="Context data",
    )
    # Metric score; left as None when no score was computed.
    score: Optional[float] = Field(
        default=None,
        description="Score for the prediction in now metric",
    )
    # Defaults to True.
    passing: Optional[bool] = Field(
        default=True,
        description="Determine whether the current prediction result is valid",
    )
    metric_name: Optional[str] = Field(
        default=None,
        description="Name of the metric",
    )
    # Cost of producing the prediction; the unit is not defined here.
    prediction_cost: int = 0
 | ||
| 
 | ||
| 
 | ||
class EvaluationResult(BaseEvaluationResult):
    """Evaluation result.

    Output of an BaseEvaluator: a base result enriched with the originating
    query, the raw dataset record and optional feedback text.
    """

    # The query the prediction answers.
    query: Optional[QueryType] = Field(default=None, description="Query data")
    # The untouched dataset record this result was produced from.
    raw_dataset: Optional[Any] = Field(default=None, description="Raw dataset")
    # Free-form feedback text.
    feedback: Optional[str] = Field(default=None, description="feedback")
 | ||
| 
 | ||
| 
 | ||
# Generic type parameters for metrics.
# Q: query type (not referenced by the definitions in this module).
Q = TypeVar("Q")
# P: prediction type accepted by a metric.
P = TypeVar("P")
# C: element type of the contexts a metric compares against.
C = TypeVar("C")
 | ||
| 
 | ||
| 
 | ||
class EvaluationMetric(ABC, Generic[P, C]):
    """Base class for evaluation metric.

    Subclasses implement either the async ``compute`` or the blocking
    ``sync_compute``. The default ``compute`` runs ``sync_compute`` in the
    running event loop's default executor.
    """

    def __init__(self, **kwargs):  # noqa
        """Accept and ignore arbitrary keyword arguments."""
        pass

    @classmethod
    def name(cls) -> str:
        """Name of the metric (the class name by default)."""
        return cls.__name__

    @classmethod
    def describe(cls) -> str:
        """Return a human-readable description of the metric."""
        # ``name`` is normally a classmethod and has to be *called*; formatting
        # ``cls.name`` directly would embed the repr of a bound method.
        # Subclasses may override ``name`` with an instance property (as
        # FunctionMetric does); accessed on the class that is a non-callable
        # property object, so fall back to the class name in that case.
        name_attr = cls.name
        metric_name = name_attr() if callable(name_attr) else cls.__name__
        return (
            "This is an evaluation result index calculation tool, named "
            f"{metric_name} "
        )

    async def compute(
        self,
        prediction: P,
        contexts: Optional[Sequence[C]] = None,
        query: Optional[str] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric.

        The default implementation offloads ``sync_compute`` to the running
        event loop's default executor so blocking metrics do not stall the loop.

        Args:
            prediction(P): The prediction data.
            contexts(Optional[Sequence[C]]): The context data.
            query(Optional[str]): The query text.

        Returns:
            BaseEvaluationResult: The evaluation result.
        """
        return await asyncio.get_running_loop().run_in_executor(
            None, self.sync_compute, prediction, contexts, query
        )

    def sync_compute(
        self,
        prediction: P,
        contexts: Optional[Sequence[C]] = None,
        query: Optional[str] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric synchronously.

        Args:
            prediction(P): The prediction data.
            contexts(Optional[Sequence[C]]): The context data.
            query(Optional[str]): The query text.

        Returns:
            BaseEvaluationResult: The evaluation result.

        Raises:
            NotImplementedError: Always, unless a subclass overrides it.
        """
        raise NotImplementedError("sync_compute is not implemented")
 | ||
| 
 | ||
| 
 | ||
class FunctionMetric(EvaluationMetric[P, C], Generic[P, C]):
    """Evaluation metric based on a function."""

    def __init__(self, **kwargs):
        """Create a FunctionMetric.

        Args:
            name(str): The name of the metric.
            func(Callable[[P, Optional[Sequence[C]]], BaseEvaluationResult]):
                The function to use for evaluation. A coroutine function is
                also accepted; its awaitable result is awaited by ``compute``.

        Raises:
            ValueError: If ``name`` or ``func`` is not supplied.
        """
        if "name" not in kwargs:
            raise ValueError("Must need param name")

        if "func" not in kwargs:
            raise ValueError("Must need param func")
        self._name = kwargs["name"]
        self.func = kwargs["func"]

    @property
    def name(self) -> str:  # type: ignore # noqa
        """Name of the metric."""
        return self._name

    async def compute(
        self,
        prediction: P,
        contexts: Optional[Sequence[C]] = None,
        query: Optional[str] = None,
        *,
        context: Optional[Sequence[C]] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric by calling ``func``.

        Args:
            prediction(P): The prediction data.
            contexts(Optional[Sequence[C]]): The context data. Named like the
                base class parameter so keyword callers work uniformly.
            query(Optional[str]): The query text. Not forwarded to ``func``,
                whose documented signature takes only prediction and contexts.
            context(Optional[Sequence[C]]): Deprecated keyword-only alias for
                ``contexts``, kept for callers of the previous signature. Used
                only when ``contexts`` is None.

        Returns:
            BaseEvaluationResult: Whatever ``func`` returns (awaited if it is
                awaitable).
        """
        import inspect

        if contexts is None:
            contexts = context
        result = self.func(prediction, contexts)
        # Support async evaluation functions without forcing callers to wrap them.
        if inspect.isawaitable(result):
            result = await result
        return result
 | ||
| 
 | ||
| 
 | ||
class ExactMatchMetric(EvaluationMetric[str, str]):
    """Exact match metric.

    Just support string prediction and context. The score is 1.0 when the
    (optionally normalized) prediction is ``in`` the contexts, otherwise 0.0.
    """

    # Deletes every character of ``string.punctuation``; built once rather than
    # on every call and for every context.
    _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

    def __init__(self, **kwargs):
        """Create an ExactMatchMetric.

        Args:
            ignore_case(bool): Lower-case prediction and contexts before
                matching. Defaults to False.
            ignore_punctuation(bool): Remove ``string.punctuation`` characters
                before matching. Defaults to False.
        """
        self._ignore_case = kwargs.get("ignore_case", False)
        self._ignore_punctuation = kwargs.get("ignore_punctuation", False)

    def _normalize(self, text: str) -> str:
        """Apply the configured case and punctuation normalization to text."""
        if self._ignore_case:
            text = text.lower()
        if self._ignore_punctuation:
            text = text.translate(self._PUNCTUATION_TABLE)
        return text

    async def compute(
        self,
        prediction: str,
        contexts: Optional[Sequence[str]] = None,
        query: Optional[str] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric.

        Args:
            prediction(str): The prediction text.
            contexts(Optional[Sequence[str]]): Candidate texts to match against.
            query(Optional[str]): Unused by this metric.

        Returns:
            BaseEvaluationResult: The normalized prediction and contexts plus
                the score (0.0 when there are no contexts).
        """
        prediction = self._normalize(prediction)
        # Only rebuild contexts when a normalization is actually configured, so
        # the caller's sequence is returned untouched otherwise.
        if contexts and (self._ignore_case or self._ignore_punctuation):
            contexts = [self._normalize(c) for c in contexts]
        score = float(prediction in contexts) if contexts else 0.0
        return BaseEvaluationResult(
            prediction=prediction,
            contexts=contexts,
            score=score,
        )
 | ||
| 
 | ||
| 
 | ||
class SimilarityMetric(EvaluationMetric[str, str]):
    """Similarity metric.

    Calculate the cosine similarity between a prediction and a list of contexts.
    The reported score is the mean of those similarities.
    """

    def __init__(self, **kwargs):
        """Create a SimilarityMetric with embeddings.

        Args:
            embeddings(Embeddings): Embedding service used to compute the
                similarities. Required.

        Raises:
            ValueError: If ``embeddings`` is missing or not an ``Embeddings``.
        """
        self._embeddings = kwargs.get("embeddings", None)
        # isinstance() is False for None, so this also covers a missing value.
        if not isinstance(self._embeddings, Embeddings):
            raise ValueError("Need embedding service!")

    def sync_compute(
        self,
        prediction: str,
        contexts: Optional[Sequence[str]] = None,
        query: Optional[str] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric.

        Args:
            prediction(str): The prediction text.
            contexts(Optional[Sequence[str]]): Texts to compare against.
            query(Optional[str]): Unused by this metric.

        Returns:
            BaseEvaluationResult: Score is the mean cosine similarity, or 0.0
                when there are no contexts.

        Raises:
            ImportError: If numpy is not installed.
        """
        if not contexts:
            return BaseEvaluationResult(
                prediction=prediction,
                contexts=contexts,
                score=0.0,
            )
        try:
            import numpy as np
        except ImportError as err:
            # Chain explicitly so the traceback shows a cause, not a second error.
            raise ImportError("numpy is required for SimilarityMetric") from err

        similarity: np.ndarray = calculate_cosine_similarity(
            self._embeddings, prediction, contexts
        )
        return BaseEvaluationResult(
            prediction=prediction,
            contexts=contexts,
            score=float(similarity.mean()),
        )
 | ||
| 
 | ||
| 
 | ||
class Evaluator(ABC):
    """Abstract base for evaluators that score a dataset with a set of metrics."""

    def __init__(
        self,
        llm_client: Optional[LLMClient] = None,
    ):
        """Create an Evaluator.

        Args:
            llm_client(Optional[LLMClient]): Optional LLM client, kept on the
                instance as ``llm_client`` for subclasses to use.
        """
        self.llm_client = llm_client

    @abstractmethod
    async def evaluate(
        self,
        dataset: DatasetType,
        metrics: Optional[List[EvaluationMetric]] = None,
        query_key: str = "query",
        contexts_key: str = "contexts",
        prediction_key: str = "prediction",
        parallel_num: int = 1,
        **kwargs,
    ) -> List[List[EvaluationResult]]:
        """Evaluate every record of a dataset against the given metrics.

        Args:
            dataset(DatasetType): The dataset to evaluate.
            metrics(Optional[List[EvaluationMetric]]): Metrics applied to each
                record.
            query_key(str): Key under which a record stores its query.
            contexts_key(str): Key under which a record stores its contexts.
            prediction_key(str): Key under which a record stores its prediction.
            parallel_num(int): How many evaluation tasks may run in parallel.
            kwargs: Extra implementation-specific arguments.

        Returns:
            List[List[EvaluationResult]]: One entry per dataset record; each
                entry is the list of results produced by the metrics for that
                record.
        """
 | ||
| 
 | ||
| 
 | ||
class MetricManage:
    """Registry of evaluation metric classes keyed by their string name."""

    def __init__(self):
        """Init metricManage with an empty registry."""
        # Maps metric name (str) -> metric class.
        self.metrics = defaultdict()

    @staticmethod
    def _name_of(metric_cls: Type[EvaluationMetric]) -> str:
        """Return the registry key (string name) for a metric class."""
        name_attr = metric_cls.name
        # ``name`` is normally a classmethod and must be called; a subclass that
        # overrides it with an instance property yields a non-callable property
        # object on the class, so fall back to the class name.
        return name_attr() if callable(name_attr) else metric_cls.__name__

    def register_metric(self, cls: Type[EvaluationMetric]):
        """Register a metric class under its name.

        Args:
            cls(Type[EvaluationMetric]): The metric class. It is stored under the
                string returned by ``cls.name()`` (previously the bound method
                itself was used as key, which made string lookups impossible).
        """
        self.metrics[self._name_of(cls)] = cls

    def get_by_name(self, name: str) -> Type[EvaluationMetric]:
        """Get a registered metric class by name.

        Args:
            name(str): The metric name. For backward compatibility the bound
                classmethod (e.g. ``SomeMetric.name``) is accepted as well.

        Returns:
            Type[EvaluationMetric]: The registered metric class.

        Raises:
            ValueError: If no metric is registered under ``name``.
        """
        if callable(name):
            name = name()
        if name not in self.metrics:
            raise ValueError(f"Metric:{name} not register!")
        return self.metrics[name]

    def all_metric_infos(self):
        """Get name and description strings for every registered metric.

        Returns:
            List[dict]: One ``{"name": str, "describe": str}`` dict per metric,
                in registration order.
        """
        return [
            {
                "name": name,
                "describe": cls.describe(),
            }
            for name, cls in self.metrics.items()
        ]
 | ||
| 
 | ||
| 
 | ||
# Module-wide metric registry. ``metric_mange`` is the historical (misspelled)
# public name and is kept for existing importers; ``metric_manage`` is a
# correctly spelled alias bound to the same object.
metric_mange = MetricManage()
metric_manage = metric_mange
 |