import json
import logging
from abc import ABC
from typing import Any, List, Optional

from dbgpt.core.interface.evaluation import (
    BaseEvaluationResult,
    EvaluationMetric,
    metric_mange,
)

logger = logging.getLogger(__name__)


class AppLinkMetric(EvaluationMetric[str, str], ABC):
    """AppLink evaluation metric.

    Parses the JSON answer returned by the AppLink agent, extracts its
    "app_name" field, and scores 1 when that app name appears in the expected
    answer from the dataset.
    """

    @classmethod
    @property
    def describe(cls) -> str:
        return "Evaluates the correctness of results returned by AppLink."

    def sync_compute(
        self,
        prediction: Optional[str] = None,
        contexts: Optional[str] = None,
        query: Optional[str] = None,
        **kwargs: Any,
    ) -> BaseEvaluationResult:
        """Compute the AppLink metric.

        Args:
            prediction(Optional[str]): The JSON answer returned by the AppLink
                agent; its "app_name" field is compared against the expected
                answer.
            contexts(Optional[str]): The expected answer from the dataset.
            query(Optional[str]): The query text.

        Returns:
            BaseEvaluationResult: The evaluation result.
        """
        score = 0
        prediction_result = None
        passing = True
        try:
            prediction_result = prediction
            if not prediction or len(prediction) <= 0:
                passing = False
            else:
                prediction_dict = json.loads(prediction)
                intent = prediction_dict.get("app_name", None)
                prediction_result = intent
                if intent in contexts:
                    score = 1
        except Exception as e:
            logger.warning(f"AppLinkMetric sync_compute exception {str(e)}")
            if prediction == contexts:
                score = 1

        return BaseEvaluationResult(
            score=score,
            prediction=prediction_result,
            passing=passing,
        )
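
# Illustrative note (an assumption, not from the original source): a prediction
# such as '{"app_name": "chat_knowledge"}' scores 1 when "chat_knowledge"
# appears in the expected contexts string, while an empty prediction is marked
# as not passing.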


class IntentMetric(EvaluationMetric[str, str], ABC):
    """Intent evaluation metric.

    Parses the JSON answer returned by the intent-recognition agent, extracts
    its "intent" field, and scores 1 when that intent appears in the expected
    answer from the dataset.
    """

    @classmethod
    @property
    def describe(cls) -> str:
        return (
            "Evaluates the correctness of results returned by the "
            "intent recognition agent."
        )

    def sync_compute(
        self,
        prediction: Optional[str] = None,
        contexts: Optional[str] = None,
        query: Optional[str] = None,
        **kwargs: Any,
    ) -> BaseEvaluationResult:
        """Compute the Intent metric.

        Args:
            prediction(Optional[str]): The JSON answer returned by the
                intent-recognition agent; its "intent" field is compared
                against the expected answer.
            contexts(Optional[str]): The expected answer from the dataset.
            query(Optional[str]): The query text.

        Returns:
            BaseEvaluationResult: The evaluation result.
        """
        score = 0
        prediction_result = None
        # Default to not passing so the variable is always defined, even when
        # an exception is raised or the intent does not match.
        passing = False
        try:
            prediction_result = prediction
            if not prediction:
                passing = False
            else:
                prediction_dict = json.loads(prediction)
                intent = prediction_dict.get("intent", None)
                prediction_result = intent
                if intent in contexts:
                    score = 1
                    passing = True
        except Exception as e:
            logger.warning(f"IntentMetric sync_compute exception {str(e)}")
            # If parsing or matching fails, fall back to an exact string
            # comparison between prediction and the expected answer.
            if prediction == contexts:
                score = 1

        return BaseEvaluationResult(
            score=score,
            prediction=prediction_result,
            passing=passing,
        )


metric_mange.register_metric(IntentMetric)
metric_mange.register_metric(AppLinkMetric)
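

# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes IntentMetric can be constructed without arguments and that
# sync_compute can be called directly; the JSON payload and expected answer
# below are made-up examples of what an intent-recognition agent might return.
if __name__ == "__main__":
    metric = IntentMetric()
    result = metric.sync_compute(
        prediction='{"intent": "chat_with_db_execute"}',
        contexts="chat_with_db_execute",
        query="Show me the total sales for last month.",
    )
    print(result.score, result.prediction, result.passing)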