fix(benchmark): parse multi standard answer
@@ -1,10 +1,10 @@
+import asyncio
 import json
 import logging
-import asyncio
 from functools import cache
 from typing import List, Optional

-from fastapi import APIRouter, Depends, Query, HTTPException, BackgroundTasks
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
 from fastapi.responses import StreamingResponse
 from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
@@ -36,6 +36,7 @@ router = APIRouter()
 global_system_app: Optional[SystemApp] = None
 logger = logging.getLogger(__name__)


 def _run_benchmark_task_sync(
     service: BenchmarkService,
     evaluate_code: str,
@@ -61,11 +62,15 @@ def _run_benchmark_task_sync(
                     model_list,
                 )
             )
-            logger.info(f"Benchmark task run sync finish, evaluate_code: {evaluate_code}")
+            logger.info(
+                f"Benchmark task run sync finish, evaluate_code: {evaluate_code}"
+            )
         finally:
             loop.close()
     except Exception as e:
-        logger.error(f"Benchmark task failed for evaluate_code: {evaluate_code}, error: {str(e)}")
+        logger.error(
+            f"Benchmark task failed for evaluate_code: {evaluate_code}, error: {str(e)}"
+        )


 def get_service() -> Service:
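The hunk above is the sync wrapper that a FastAPI background task calls: it runs the async benchmark coroutine to completion on a private event loop and always closes that loop. A minimal sketch of the same pattern, assuming a hypothetical `coro_factory` callable that returns the coroutine (the real method and its arguments are those in the diff, not shown here):

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


def run_async_job_sync(coro_factory, evaluate_code: str) -> None:
    """Run an async job to completion from a plain (sync) background task."""
    try:
        # A dedicated event loop keeps the background thread independent of
        # the server's running loop.
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(coro_factory())
            logger.info(f"Benchmark task finished, evaluate_code: {evaluate_code}")
        finally:
            # Close the loop even if the coroutine raised.
            loop.close()
    except Exception as e:
        logger.error(f"Benchmark task failed, evaluate_code: {evaluate_code}, error: {e}")
```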
@@ -291,10 +296,9 @@ async def execute_benchmark_task(
     )

     # Return a success response immediately
-    return Result.succ({
-        "evaluate_code": request.evaluate_code,
-        "status": Status.RUNNING.value
-    })
+    return Result.succ(
+        {"evaluate_code": request.evaluate_code, "status": Status.RUNNING.value}
+    )


 @router.get("/benchmark_task_list", dependencies=[Depends(check_api_key)])
@@ -305,7 +309,7 @@ async def benchmark_task_list(
     service: BenchmarkService = Depends(get_benchmark_service),
 ) -> Result:
     """
-    Query benchmark task list
+    Query benchmark task list
     """
     return Result.succ(
         service.get_list_by_page(
@@ -344,47 +348,51 @@ async def get_benchmark_table_rows(table: str, limit: int = 10):

 @router.get("/benchmark_result_download", dependencies=[Depends(check_api_key)])
 async def download_benchmark_result(
-    evaluate_code: Optional[str] = Query(default=None, description="evaluate code"),
-    service: BenchmarkService = Depends(get_benchmark_service),
+    evaluate_code: Optional[str] = Query(default=None, description="evaluate code"),
+    service: BenchmarkService = Depends(get_benchmark_service),
 ):
     """Download benchmark result file

     Args:
         evaluate_code: The evaluation code to identify the benchmark result
         service: The benchmark service instance

     Returns:
         StreamingResponse: File download response

     Raises:
         HTTPException: If evaluation code is missing or file not found
     """
     logger.info(f"download benchmark result: {evaluate_code}")

     if not evaluate_code:
         raise HTTPException(status_code=400, detail="evaluate_code is required")

     try:
         # Get the file name and the file stream
         file_name, file_stream = await service.get_benchmark_file_stream(evaluate_code)

         from urllib.parse import quote

         # Encode the file name so Chinese and special characters are supported
         encoded_filename = quote(file_name)

         # Return the file download response
         return StreamingResponse(
             content=file_stream,
             media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
             headers={
-                "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}; filename={encoded_filename}",
-                "Content-Type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
-            }
+                "Content-Disposition": f"attachment; filename*=UTF-8''"
+                f"{encoded_filename}; filename={encoded_filename}",
+                "Content-Type": "application/vnd.openxmlformats-"
+                "officedocument.spreadsheetml.sheet",
+            },
         )

     except Exception as e:
-        logger.error(f"Failed to download benchmark result for {evaluate_code}: {str(e)}")
+        logger.error(
+            f"Failed to download benchmark result for {evaluate_code}: {str(e)}"
+        )
         raise HTTPException(status_code=404, detail=str(e))
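The header built above percent-encodes the file name so a non-ASCII (for example Chinese) name survives the `Content-Disposition` header, sending both the RFC 5987 `filename*` form and a plain `filename` fallback. A small hedged sketch of just that encoding step, independent of FastAPI:

```python
from urllib.parse import quote


def content_disposition(file_name: str) -> str:
    """Build an attachment Content-Disposition value for a possibly non-ASCII name."""
    encoded = quote(file_name)
    # filename* carries the UTF-8, percent-encoded name; filename is a fallback.
    return f"attachment; filename*=UTF-8''{encoded}; filename={encoded}"


# Example: a Chinese file name becomes an ASCII-safe header value.
print(content_disposition("评测结果.xlsx"))
```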
@@ -393,4 +401,4 @@ def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None:
     global global_system_app
     system_app.register(Service, config=config)
     system_app.register(BenchmarkService, config=config)
-    global_system_app = system_app
+    global_system_app = system_app
@@ -91,4 +91,4 @@ class BenchmarkServeRequest(BaseModel):
 class StorageType(Enum):
     FILE = "FILE"
     OSS = "OSS"
-    YU_QUE = "YU_QUE"
+    YU_QUE = "YU_QUE"
@@ -3,7 +3,9 @@ from typing import Optional, Union

 from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
 from dbgpt_serve.evaluate.service.benchmark.models import ReasoningResponse
-from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import get_benchmark_manager
+from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
+    get_benchmark_manager,
+)

 logger = logging.getLogger(__name__)

@@ -30,7 +32,7 @@ class BenchmarkLLMTask:
     ) -> Union[ReasoningResponse, None]:
         """
         Invoke by LLM.

         Args:
             prompt (Optional[str]): The prompt to use for the LLM.
             **kwargs: Keyword arguments for variable replacement in prompt template.
@@ -46,9 +48,9 @@ class BenchmarkLLMTask:
         template = HumanPromptTemplate.from_template(
             template=prompt, template_is_strict=False
         )
-        if self.dialect and 'dialect' not in kwargs:
-            kwargs['dialect'] = self.dialect
-
+        if self.dialect and "dialect" not in kwargs:
+            kwargs["dialect"] = self.dialect
+
         messages = template.format_messages(**kwargs)

         # use default model if needed
@@ -70,9 +72,11 @@ class BenchmarkLLMTask:
             return None

         if response.has_text:
-            return ReasoningResponse(cot_tokens=response.usage.get("total_tokens", 0),
-                                     think=response.thinking_text if response.has_thinking else None,
-                                     content=self._get_answer(response.text))
+            return ReasoningResponse(
+                cot_tokens=response.usage.get("total_tokens", 0),
+                think=response.thinking_text if response.has_thinking else None,
+                content=self._get_answer(response.text),
+            )
         else:
             return None

@@ -1,3 +1,4 @@
+import io
 import json
 import logging
 import os
@@ -7,8 +8,7 @@ import uuid
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, List, Optional, Union, Any, Tuple
-import io
+from typing import Dict, List, Optional, Tuple, Union

 from dbgpt.agent.core.schema import Status
 from dbgpt.component import ComponentType, SystemApp
@@ -17,18 +17,23 @@ from dbgpt.core import LLMClient
 from dbgpt.model import DefaultLLMClient
 from dbgpt.model.cluster import WorkerManagerFactory
 from dbgpt.storage.metadata import BaseDao
-from dbgpt.util import get_or_create_event_loop, PaginationResult
+from dbgpt.util import PaginationResult, get_or_create_event_loop

 from ....core import BaseService
 from ....prompt.service.service import Service as PromptService
 from ....rag.service.service import Service as RagService
 from ....rag.storage_manager import StorageManager
-from ...api.schemas import EvaluateServeRequest, EvaluateServeResponse, StorageType, EvaluationScene
+from ...api.schemas import (
+    EvaluateServeRequest,
+    EvaluateServeResponse,
+    EvaluationScene,
+    StorageType,
+)
 from ...config import ServeConfig
 from ...models.models import ServeDao, ServeEntity
 from .benchmark_llm_task import BenchmarkLLMTask
 from .data_compare_service import DataCompareService
-from .file_parse_service import ExcelFileParseService, FileParseService
+from .file_parse_service import ExcelFileParseService
 from .models import (
     BaseInputModel,
     BenchmarkDataSets,
@@ -49,15 +54,13 @@ executor = ThreadPoolExecutor(max_workers=5)

 BENCHMARK_SERVICE_COMPONENT_NAME = "dbgpt_serve_evaluate_benchmark_service"

 # TODO: needs to be replaced with the official file
 STANDARD_BENCHMARK_FILE_PATH = os.path.join(
     BENCHMARK_DATA_ROOT_PATH,
-    "2025_07_27_public_500_standard_benchmark_question_list_v2_local_test.xlsx"
+    "2025_07_27_public_500_standard_benchmark_question_list_multi_anwser.xlsx",
 )

-BENCHMARK_OUTPUT_RESULT_PATH = os.path.join(
-    BENCHMARK_DATA_ROOT_PATH,
-    "result"
-)
+BENCHMARK_OUTPUT_RESULT_PATH = os.path.join(BENCHMARK_DATA_ROOT_PATH, "result")


 def get_rag_service(system_app) -> RagService:
@@ -94,8 +97,9 @@ class BenchmarkService(
             max_workers=5, thread_name_prefix="benchmark-fileWrite"
         )

-        self.output_base_file_name = f"{datetime.now().strftime('%Y%m%d')}_multi_round_benchmark_result.xlsx"
+        self.output_base_file_name = (
+            f"{datetime.now().strftime('%Y%m%d%H%M')}_multi_round_benchmark_result.xlsx"
+        )

     def init_app(self, system_app: SystemApp) -> None:
         """Initialize the service
@@ -127,7 +131,6 @@ class BenchmarkService(
         ).create()
         return DefaultLLMClient(worker_manager, True)

-
     def create_benchmark_task(
         self,
         config: BenchmarkExecuteConfig,
@@ -135,7 +138,7 @@ class BenchmarkService(
         scene_key: str,
         scene_value: str,
         input_file_path: str,
-        output_file_path: str
+        output_file_path: str,
     ) -> bool:
         """
         Save the benchmark task to the database
@@ -155,20 +158,24 @@ class BenchmarkService(
             evaluate_code=evaluate_code,
             scene_key=scene_key,
             scene_value=scene_value,
-            datasets_name=os.path.basename(input_file_path) if input_file_path else None,
+            datasets_name=os.path.basename(input_file_path)
+            if input_file_path
+            else None,
             datasets=None,
             storage_type=StorageType.FILE.value,
             parallel_num=1,
             state=Status.RUNNING.value,
             result=output_file_path,
             context={
-                "benchmark_config": json.dumps(config.to_dict(), ensure_ascii=False),
+                "benchmark_config": json.dumps(
+                    config.to_dict(), ensure_ascii=False
+                ),
             },
             user_id=None,
             user_name=None,
             sys_code="benchmark_system",
         )

         response = self.create(request_data)
         logger.info(
             f"Successfully saved benchmark task to database: "
@@ -184,7 +191,9 @@ class BenchmarkService(
             )
             return False

-    def _generate_output_file_full_path(self, output_file_path: str, evaluate_code: str) -> str:
+    def _generate_output_file_full_path(
+        self, output_file_path: str, evaluate_code: str
+    ) -> str:
         """
         Generate the complete output file path,
         including the evaluate_code subfolder and default filename
@@ -229,15 +238,22 @@ class BenchmarkService(
         if not scene_key:
             scene_key = EvaluationScene.DATASET.value

-        output_file_path = (
-            self._generate_output_file_full_path(output_file_path, evaluate_code)
+        output_file_path = self._generate_output_file_full_path(
+            output_file_path, evaluate_code
         )

-        config = await self._build_benchmark_config(model_list, output_file_path,
-                                                    evaluate_code, scene_key)
+        config = await self._build_benchmark_config(
+            model_list, output_file_path, evaluate_code, scene_key
+        )
         # save benchmark task
-        self.create_benchmark_task(config, evaluate_code, scene_key, scene_value,
-                                   input_file_path, output_file_path)
+        self.create_benchmark_task(
+            config,
+            evaluate_code,
+            scene_key,
+            scene_value,
+            input_file_path,
+            output_file_path,
+        )

         # read input file
         input_list: List[BaseInputModel] = (
@@ -281,12 +297,15 @@ class BenchmarkService(
             )
             result_list.extend(round_result_list)

-        logger.info(f"Benchmark task completed successfully for evaluate_code:"
-                    f" {evaluate_code}, output_file_path: {output_file_path}")
+        logger.info(
+            f"Benchmark task completed successfully for evaluate_code:"
+            f" {evaluate_code}, output_file_path: {output_file_path}"
+        )
         return result_list

-    async def _build_benchmark_config(self, model_list, output_file_path,
-                                      evaluate_code, scene_key) -> BenchmarkExecuteConfig:
+    async def _build_benchmark_config(
+        self, model_list, output_file_path, evaluate_code, scene_key
+    ) -> BenchmarkExecuteConfig:
         config = BenchmarkExecuteConfig(
             benchmark_mode_type=BenchmarkModeTypeEnum.EXECUTE,
             standard_file_path=STANDARD_BENCHMARK_FILE_PATH,
@@ -320,7 +339,9 @@ class BenchmarkService(
         """
         result = BenchmarkTaskResult[OutputType]()
         result.trace_id = uuid.uuid4().hex
-        result.task_id = config.evaluate_code if config.evaluate_code else uuid.uuid4().hex
+        result.task_id = (
+            config.evaluate_code if config.evaluate_code else uuid.uuid4().hex
+        )
         result.start_time = datetime.now()

         executor = ThreadPoolExecutor(
@@ -389,7 +410,7 @@ class BenchmarkService(
         return result

     async def execute(
-        self, config: BenchmarkExecuteConfig, input: InputType
+        self, config: BenchmarkExecuteConfig, input: InputType
     ) -> Union[OutputType, None]:
         """
         Execute LLM Benchmark Task
@@ -405,16 +426,16 @@ class BenchmarkService(
                 llm_client=self.llm_client, model_name=input.llm_code
             )

-            response: ReasoningResponse = await (
-                benchmark_llm_task_service.invoke_llm(prompt=input.prompt)
+            response: ReasoningResponse = await benchmark_llm_task_service.invoke_llm(
+                prompt=input.prompt
             )

             # 3. Assemble the evaluation output
-            return await self.user_input_execute_service.build_output(config, input, response)
-        except Exception as e:
-            logger.error(
-                f"execute benchmark error! error: {e}"
-            )
+            return await self.user_input_execute_service.build_output(
+                config, input, response
+            )
+        except Exception as e:
+            logger.error(f"execute benchmark error! error: {e}")
             return None

     def check_and_trigger_batch(
@@ -485,13 +506,18 @@ class BenchmarkService(
                 batch_write_task()
                 written_batches.add(batch_index)

-    def post_dispatch(self, i: int, config: BenchmarkExecuteConfig,
-                      input_list: List[BaseInputModel],
-                      output_list: List[BenchmarkTaskResult[OutputType]],
-                      input_file_path: str, output_file_path: str):
+    def post_dispatch(
+        self,
+        i: int,
+        config: BenchmarkExecuteConfig,
+        input_list: List[BaseInputModel],
+        output_list: List[BenchmarkTaskResult[OutputType]],
+        input_file_path: str,
+        output_file_path: str,
+    ):
         """
-        Post dispatch processing standard result compare LLM execute result
-        and write compare result to file
+        Post dispatch processing standard result compare LLM execute result
+        and write compare result to file
         """
         for j, output_result in enumerate(output_list):
             self.user_input_execute_service.post_dispatch(
@@ -518,9 +544,13 @@ class BenchmarkService(
             List[EvaluateServeResponse]: The response
         """
         query_request = request
-        return self.dao.get_list_page(query_request, page, page_size, ServeEntity.id.name)
+        return self.dao.get_list_page(
+            query_request, page, page_size, ServeEntity.id.name
+        )

-    async def get_benchmark_file_stream(self, evaluate_code: str) -> Tuple[str, io.BytesIO]:
+    async def get_benchmark_file_stream(
+        self, evaluate_code: str
+    ) -> Tuple[str, io.BytesIO]:
         """Get benchmark result file stream for download

         Args:
@@ -539,7 +569,9 @@ class BenchmarkService(
         try:
             entity = self.dao.get_one({"evaluate_code": evaluate_code})
             if not entity:
-                raise Exception(f"Evaluation record not found for code: {evaluate_code}")
+                raise Exception(
+                    f"Evaluation record not found for code: {evaluate_code}"
+                )
         except Exception as e:
             logger.error(f"Failed to query evaluation record: {e}")
             raise Exception(f"Failed to query evaluation record: {str(e)}")
@@ -547,7 +579,9 @@ class BenchmarkService(
         # 2. Resolve the file from the result path stored on the record
         file_path = entity.result
         if not file_path:
-            raise Exception(f"No result file path found for evaluate_code: {evaluate_code}")
+            raise Exception(
+                f"No result file path found for evaluate_code: {evaluate_code}"
+            )

         # Check whether the file exists
         if not os.path.exists(file_path):
@@ -555,18 +589,18 @@ class BenchmarkService(

         try:
             # Read the file content into memory
-            with open(file_path, 'rb') as file:
+            with open(file_path, "rb") as file:
                 file_content = file.read()

             # Create an in-memory byte stream
             file_stream = io.BytesIO(file_content)

             # Get the file name
             file_name = os.path.basename(file_path)

             logger.info(f"Successfully prepared file stream for download: {file_name}")
             return file_name, file_stream

         except Exception as e:
             logger.error(f"Failed to read result file {file_path}: {e}")
-            raise Exception(f"Failed to read result file: {str(e)}")
+            raise Exception(f"Failed to read result file: {str(e)}")
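As a design note, `get_benchmark_file_stream` buffers the whole result file in memory and hands back a name plus an `io.BytesIO`; for very large result files, chunked streaming would be preferable. A minimal sketch of the buffering pattern shown above:

```python
import io
import os
from typing import Tuple


def load_file_as_stream(file_path: str) -> Tuple[str, io.BytesIO]:
    """Read a file fully into memory and return (file_name, byte_stream)."""
    with open(file_path, "rb") as f:
        content = f.read()
    # BytesIO gives downstream code a seekable, iterable file-like object.
    return os.path.basename(file_path), io.BytesIO(content)
```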
@@ -74,6 +74,7 @@ class DataCompareService:
         if not cfg.standard_result:
             return DataCompareResult.failed("leftResult is null")

+        # Compare against every standard answer: as long as one standard answer is contained, the result is considered correct; otherwise it is wrong
         for std in cfg.standard_result:
             if not isinstance(std, dict):
                 continue
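The added comment describes "match any reference answer" semantics: the LLM result counts as RIGHT as soon as a single standard answer matches. A hedged sketch of that control flow, where `contains()` is a hypothetical stand-in for the service's actual CONTAIN_MATCH comparison:

```python
from typing import Dict, List


def matches_any_standard(
    llm_result: Dict[str, List[str]],
    standard_results: List[Dict[str, List[str]]],
) -> bool:
    """Return True if the LLM result matches at least one reference answer."""

    def contains(result: Dict[str, List[str]], std: Dict[str, List[str]]) -> bool:
        # Hypothetical containment check: every standard column's values must
        # appear in the corresponding column of the LLM result.
        return all(
            col in result and all(v in result[col] for v in values)
            for col, values in std.items()
        )

    for std in standard_results:
        if not isinstance(std, dict):
            continue  # skip malformed entries, as the service does
        if contains(llm_result, std):
            return True  # any single match makes the result RIGHT
    return False


print(matches_any_standard({"city": ["Hangzhou", "Beijing"]}, [{"city": ["Hangzhou"]}]))  # True
```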
@@ -3,11 +3,11 @@ import json
 import logging
 import os
 from abc import ABC, abstractmethod
-from typing import List, Any, Dict
+from pathlib import Path
+from typing import Any, Dict, List, Optional

 import pandas as pd
-from pathlib import Path
-from openpyxl import load_workbook, Workbook
+from openpyxl import Workbook, load_workbook

 from dbgpt.util.benchmarks.ExcelUtils import ExcelUtils
 from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao
@@ -16,15 +16,16 @@ from .models import (
     AnswerExecuteModel,
     BaseInputModel,
     BenchmarkDataSets,
+    BenchmarkExecuteConfig,
     DataCompareStrategyConfig,
-    RoundAnswerConfirmModel, BenchmarkExecuteConfig, OutputType,
+    OutputType,
+    RoundAnswerConfirmModel,
 )

 logger = logging.getLogger(__name__)


 class FileParseService(ABC):

     def __init__(self):
         self._benchmark_dao = BenchmarkResultDao()
@@ -32,7 +33,7 @@ class FileParseService(ABC):
         self._column_config_file_path = os.path.join(
             os.path.dirname(__file__),
             "template",
-            "benchmark_column_config_template.json"
+            "benchmark_column_config_template.json",
         )

     @abstractmethod
@@ -174,17 +175,35 @@ class FileParseService(ABC):
         extension = Path(output_path).suffix
         if extension.lower() not in [".xlsx", ".xls"]:
             extension = ".xlsx"
-        excel_file = Path(output_path).parent / f"{base_name}_round{round_id}{extension}"
+        excel_file = (
+            Path(output_path).parent / f"{base_name}_round{round_id}{extension}"
+        )
         if not excel_file.exists():
             logger.warning(f"summary excel not found: {excel_file}")
             result = dict(right=0, wrong=0, failed=0, exception=0)
             return json.dumps(result, ensure_ascii=False)

         df = pd.read_excel(str(excel_file), sheet_name="benchmark_compare_result")
-        right = int((df["compareResult"] == "RIGHT").sum()) if "compareResult" in df.columns else 0
-        wrong = int((df["compareResult"] == "WRONG").sum()) if "compareResult" in df.columns else 0
-        failed = int((df["compareResult"] == "FAILED").sum()) if "compareResult" in df.columns else 0
-        exception = int((df["compareResult"] == "EXCEPTION").sum()) if "compareResult" in df.columns else 0
+        right = (
+            int((df["compareResult"] == "RIGHT").sum())
+            if "compareResult" in df.columns
+            else 0
+        )
+        wrong = (
+            int((df["compareResult"] == "WRONG").sum())
+            if "compareResult" in df.columns
+            else 0
+        )
+        failed = (
+            int((df["compareResult"] == "FAILED").sum())
+            if "compareResult" in df.columns
+            else 0
+        )
+        exception = (
+            int((df["compareResult"] == "EXCEPTION").sum())
+            if "compareResult" in df.columns
+            else 0
+        )

         result = dict(right=right, wrong=wrong, failed=failed, exception=exception)
         logger.info(
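As a design note, the four guarded counts above could also be collapsed with pandas `value_counts`; this is a hedged alternative sketch, not what the diff does, just an equivalent way to compute the same summary:

```python
import pandas as pd


def summarize_compare_results(df: pd.DataFrame) -> dict:
    """Count RIGHT/WRONG/FAILED/EXCEPTION rows in a compare-result sheet."""
    statuses = ["right", "wrong", "failed", "exception"]
    if "compareResult" not in df.columns:
        return {s: 0 for s in statuses}
    counts = df["compareResult"].value_counts()
    return {s: int(counts.get(s.upper(), 0)) for s in statuses}


# Example usage with a tiny frame:
frame = pd.DataFrame({"compareResult": ["RIGHT", "WRONG", "RIGHT", "EXCEPTION"]})
print(summarize_compare_results(frame))  # {'right': 2, 'wrong': 1, 'failed': 0, 'exception': 1}
```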
@@ -223,10 +242,10 @@ class FileParseService(ABC):
         """
         Parse standard benchmark sets from file.
         This method must be implemented by subclasses.

         Args:
             standard_excel_path: Path to the standard benchmark file

         Returns:
             List[AnswerExecuteModel]: List of parsed answer execute models
         """
@@ -234,14 +253,14 @@ class FileParseService(ABC):

     @abstractmethod
     def write_multi_round_benchmark_result(
-        self,
-        output_file_path: str,
-        round_id: int,
-        config: BenchmarkExecuteConfig,
-        inputs: List[BaseInputModel],
-        outputs: List[OutputType],
-        start_index: int,
-        offset: int,
+        self,
+        output_file_path: str,
+        round_id: int,
+        config: BenchmarkExecuteConfig,
+        inputs: List[BaseInputModel],
+        outputs: List[OutputType],
+        start_index: int,
+        offset: int,
     ) -> bool:
         """
         Write Benchmark Task Multi round Result
@@ -257,9 +276,7 @@ class FileParseService(ABC):
         """


-
 class ExcelFileParseService(FileParseService):
-
     def parse_input_sets(self, path: str) -> BenchmarkDataSets:
         """
         Parse input sets from excel file
@@ -327,9 +344,7 @@ class ExcelFileParseService(FileParseService):
                 if workbook is not None:
                     workbook.close()
             except Exception as e:
-                logger.error(
-                    f"close workbook error, path: {path}, errorMsg: {e}"
-                )
+                logger.error(f"close workbook error, path: {path}, errorMsg: {e}")

         return input_sets

@@ -355,17 +370,15 @@ class ExcelFileParseService(FileParseService):
             except Exception:
                 order_by = True

-            std_result = None
+            std_result: Optional[List[Dict[str, List[str]]]] = None
             if not pd.isna(row.get("标准结果")):
-                try:
-                    std_result = json.loads(row.get("标准结果"))
-                except Exception:
-                    std_result = None
+                std_result_raw = str(row.get("标准结果"))
+                std_result = self._parse_multi_standard_result(std_result_raw)

             strategy_config = DataCompareStrategyConfig(
                 strategy="CONTAIN_MATCH",
                 order_by=order_by,
-                standard_result=[std_result]
+                standard_result=std_result
+                if std_result is not None
+                else None,  # use a list
             )
@@ -382,14 +395,14 @@ class ExcelFileParseService(FileParseService):
         return outputs

     def write_multi_round_benchmark_result(
-        self,
-        output_file_path: str,
-        round_id: int,
-        config: BenchmarkExecuteConfig,
-        inputs: List[BaseInputModel],
-        outputs: List[OutputType],
-        start_index: int,
-        offset: int,
+        self,
+        output_file_path: str,
+        round_id: int,
+        config: BenchmarkExecuteConfig,
+        inputs: List[BaseInputModel],
+        outputs: List[OutputType],
+        start_index: int,
+        offset: int,
     ) -> bool:
         """
         Write the benchmark Result to Excel File With Multi Round
@@ -472,8 +485,8 @@ class ExcelFileParseService(FileParseService):
             for col_idx, header in enumerate(headers, 1):
                 worksheet.cell(row=1, column=col_idx, value=header)

-            # Compute the starting row number for writing
-            # Formula: start_index + offset + 2 (+1 because Excel rows start at 1, +1 because the header occupies one row)
+            # Compute the starting row number for writing; formula: start_index + offset + 2
+            # (+1 because Excel rows start at 1, +1 because the header occupies one row)
             write_start_row = start_index + offset + 2

             # Write the data rows
@@ -492,9 +505,7 @@ class ExcelFileParseService(FileParseService):
                     if cell.value and len(str(cell.value)) > max_length:
                         max_length = len(str(cell.value))
                 except Exception as e:
-                    logger.warning(
-                        f"error while compute column length: {str(e)}"
-                    )
+                    logger.warning(f"error while compute column length: {str(e)}")
                 # Set the column width: minimum 10, maximum 50
                 adjusted_width = min(max(max_length + 2, 10), 50)
                 worksheet.column_dimensions[column_letter].width = adjusted_width
@@ -516,13 +527,13 @@ class ExcelFileParseService(FileParseService):
         return False

     def _get_value_by_source_type(
-        self,
-        field: str,
-        source_type: str,
-        processor_type: str,
-        input_data,
-        output,
-        round_id: int
+        self,
+        field: str,
+        source_type: str,
+        processor_type: str,
+        input_data,
+        output,
+        round_id: int,
     ) -> Any:
         """
         Get the value based on the source type
@@ -576,11 +587,7 @@ class ExcelFileParseService(FileParseService):
                     else ""
                 )
             else:
-                value = (
-                    str(output.executeResult)
-                    if output.executeResult
-                    else ""
-                )
+                value = str(output.executeResult) if output.executeResult else ""
         elif field == "errorMsg":
             value = output.errorMsg
         elif field == "traceId":
@@ -616,6 +623,46 @@ class ExcelFileParseService(FileParseService):
         else:
             return str(value) if value is not None else ""

+    def _parse_multi_standard_result(
+        self, std_result_raw: str
+    ) -> Optional[List[Dict[str, List[str]]]]:
+        """
+        Parse multiple standard results from raw string data.
+
+        Handles multiple results separated by newlines and parses each line as a dict.
+
+        Args:
+            std_result_raw (str): Raw standard result string with multiple lines
+
+        Returns:
+            Optional[List[Dict[str, List[str]]]]: List of parsed dictionaries,
+                or None if parsing fails or no valid data
+        """
+        try:
+            std_result_raw = std_result_raw.strip()
+            if not std_result_raw:
+                return None
+
+            # Handle multiple results, separated by newlines
+            result_lines = std_result_raw.split("\n")
+            result_list = []
+
+            for line in result_lines:
+                line = line.strip()
+                if line:
+                    try:
+                        result_list.append(json.loads(line))
+                    except Exception as e:
+                        logger.warning(
+                            f"Failed to parse line as JSON: {line}, error: {e}"
+                        )
+                        continue
+
+            return result_list if result_list else None
+        except Exception as e:
+            logger.error(f"parse multiple standard results error: {e}")
+            return None
+
     def _load_column_config(self) -> List[Dict]:
         """
         Load column configuration from JSON file
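To illustrate the new multi-answer parsing added above, here is a hedged, standalone version of the same line-by-line logic and what it returns for a spreadsheet cell that holds two reference answers (one JSON object per line); the column names in the example are made up:

```python
import json
from typing import Dict, List, Optional


def parse_multi_standard_result(raw: str) -> Optional[List[Dict[str, List[str]]]]:
    """Parse newline-separated JSON objects into a list of reference answers."""
    raw = raw.strip()
    if not raw:
        return None
    parsed: List[Dict[str, List[str]]] = []
    for line in raw.split("\n"):
        line = line.strip()
        if not line:
            continue
        try:
            parsed.append(json.loads(line))
        except Exception:
            continue  # a malformed line is skipped, not fatal
    return parsed or None


cell = '{"city": ["Hangzhou"], "gmv": ["123.4"]}\n{"city": ["Hangzhou"], "gmv": ["123.40"]}'
print(parse_multi_standard_result(cell))
# [{'city': ['Hangzhou'], 'gmv': ['123.4']}, {'city': ['Hangzhou'], 'gmv': ['123.40']}]
```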
@@ -624,10 +671,12 @@ class ExcelFileParseService(FileParseService):
             List[Dict]: List of column configurations
         """
         try:
-            with open(self._column_config_file_path, 'r', encoding='utf-8') as file:
+            with open(self._column_config_file_path, "r", encoding="utf-8") as file:
                 config_data = json.load(file)
             return config_data.get("columns", [])
         except Exception as e:
-            logger.error(f"Failed to load column configuration file: {e},"
-                         f" using default configuration")
-            raise ValueError("Failed to load column configuration file")
+            logger.error(
+                f"Failed to load column configuration file: {e},"
+                f" using default configuration"
+            )
+            raise ValueError("Failed to load column configuration file")
@@ -18,7 +18,12 @@ class BenchmarkModeTypeEnum(str, Enum):
 class DataCompareStrategyConfig:
     strategy: str  # "EXACT_MATCH" | "CONTAIN_MATCH"
     order_by: bool = True
-    standard_result: Optional[List[Dict[str, List[str]]]] = None  # changed to list[dict]
+    """
+    Standard answer, each dict in the list represents a reference answer
+    containing multiple columns of data. If any reference answer is matched,
+    the result is considered correct
+    """
+    standard_result: Optional[List[Dict[str, List[str]]]] = None


 class DataCompareResultEnum(str, Enum):
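A hedged example of what a strategy config with two reference answers looks like after parsing; the stand-in dataclass below only mirrors the shape of `DataCompareStrategyConfig`, and the column names and values are illustrative:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class StrategyConfigSketch:
    """Stand-in with the same shape as DataCompareStrategyConfig above."""

    strategy: str  # "EXACT_MATCH" | "CONTAIN_MATCH"
    order_by: bool = True
    standard_result: Optional[List[Dict[str, List[str]]]] = None


# Two alternative reference answers for one question, in column-oriented form;
# the compare step treats the LLM output as RIGHT if it matches either dict.
cfg = StrategyConfigSketch(
    strategy="CONTAIN_MATCH",
    order_by=False,
    standard_result=[
        {"month": ["2025-07"], "order_cnt": ["1024"]},
        {"month": ["2025-07"], "order_cnt": ["1,024"]},
    ],
)
print(cfg.standard_result[0])
```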
@@ -1,6 +1,6 @@
 # app/services/user_input_execute_service.py
 import logging
-from typing import Dict, List, Union, Optional
+from typing import Dict, List, Optional, Union

 from dbgpt.util.benchmarks import StorageUtil
 from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
@@ -12,13 +12,16 @@ from .file_parse_service import FileParseService
 from .models import (
     AnswerExecuteModel,
     BaseInputModel,
+    BenchmarkDataSets,
     BenchmarkExecuteConfig,
     BenchmarkModeTypeEnum,
     DataCompareResultEnum,
     DataCompareStrategyConfig,
+    FileParseTypeEnum,
     InputType,
+    OutputType,
     ReasoningResponse,
-    RoundAnswerConfirmModel, OutputType, FileParseTypeEnum, BenchmarkDataSets,
+    RoundAnswerConfirmModel,
 )

 BENCHMARK_DEFAULT_DB_SCHEMA = "ant_icube_dev."
@@ -27,7 +30,6 @@ logger = logging.getLogger(__name__)


 class UserInputExecuteService:
-
     def __init__(
         self, file_service: FileParseService, compare_service: DataCompareService
     ):
@@ -116,7 +118,9 @@ class UserInputExecuteService:
         confirm_list: List[RoundAnswerConfirmModel] = []

         # compute unique llm_count across all right answers
-        llm_codes = set([a.llm_code for a in right_answers if getattr(a, "llm_code", None)])
+        llm_codes = set(
+            [a.llm_code for a in right_answers if getattr(a, "llm_code", None)]
+        )
         llm_count = len(llm_codes) if llm_codes else len(right_answers)

         for inp in inputs:
@@ -133,7 +137,8 @@ class UserInputExecuteService:
             standard_sql = left.llmOutput
             if config.benchmark_mode_type == BenchmarkModeTypeEnum.EXECUTE:
                 strategy_cfg = left.strategyConfig
-                # Prefer the left-side execution result as the standard answer; if absent, take the first item from the strategy config's standard_result
+                # Prefer the left-side execution result as the standard answer;
+                # if absent, take the first item from the strategy config's standard_result
                 if left.executeResult is not None:
                     standard_answer = left.executeResult
                 elif left.strategyConfig and left.strategyConfig.standard_result:
@@ -149,7 +154,9 @@ class UserInputExecuteService:
                     strategy_cfg = DataCompareStrategyConfig(
                         strategy="EXACT_MATCH",
                         order_by=True,
-                        standard_result=standard_result_list if standard_result_list else None,
+                        standard_result=standard_result_list
+                        if standard_result_list
+                        else None,
                     )

             # for each right answer (per model)
@@ -166,7 +173,9 @@ class UserInputExecuteService:
                     compare_result = DataCompareResultEnum.FAILED
                 else:
                     res = self.compare_service.compare(
-                        left if left else AnswerExecuteModel(
+                        left
+                        if left
+                        else AnswerExecuteModel(
                             serialNo=inp.serial_no,
                             analysisModelId=inp.analysis_model_id,
                             question=inp.question,
@@ -266,7 +275,7 @@ class UserInputExecuteService:
         error_msg = None

         if config.execute_llm_result:
-            logger.info("[benchmark_task] queryResult start!")
+            logger.info(f"[benchmark_task] queryResult start!, sql:{sql}")
             try:
                 result: List[Dict] = await get_benchmark_manager().query(sql)
                 execute_result = self._convert_query_result_to_column_format(result)
@@ -275,7 +284,7 @@ class UserInputExecuteService:
                     f"[benchmark_task] queryResult error! sql = {sql}, errorMsg: {e}"
                 )
                 error_msg = str(e)
-            logger.info(f"[benchmark_task] queryResult end! result = {execute_result}")
+            logger.info(f"[benchmark_task] queryResult end!")

         return AnswerExecuteModel(
             serialNo=input.serial_no,
@@ -340,11 +349,11 @@ class UserInputExecuteService:

     def _process_sql_db_schema(self, sql: str) -> str:
         """
-        Process SQL remove database schema to compatible with SQLite syntax
+        Process SQL remove database schema to compatible with SQLite syntax
         """
         if not sql or not isinstance(sql, str):
             return sql

         # only replace the "ant_icube_dev." prefix
         return sql.replace(BENCHMARK_DEFAULT_DB_SCHEMA, "")
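A quick usage sketch of the schema-stripping helper's semantics: only the `ant_icube_dev.` prefix is removed so the generated SQL runs against the local SQLite copy of the benchmark data (the table and column names below are illustrative):

```python
BENCHMARK_DEFAULT_DB_SCHEMA = "ant_icube_dev."


def strip_default_schema(sql: str) -> str:
    """Remove the benchmark schema prefix so the SQL is SQLite-compatible."""
    if not sql or not isinstance(sql, str):
        return sql
    return sql.replace(BENCHMARK_DEFAULT_DB_SCHEMA, "")


print(strip_default_schema("SELECT city, SUM(gmv) FROM ant_icube_dev.orders GROUP BY city"))
# SELECT city, SUM(gmv) FROM orders GROUP BY city
```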
@@ -374,5 +383,5 @@ class UserInputExecuteService:
             bool: Returns True if write is successful, False otherwise
         """
         return self.file_service.write_multi_round_benchmark_result(
-            output_file_path, round_id, config, inputs, outputs, start_index, offset)
-
+            output_file_path, round_id, config, inputs, outputs, start_index, offset
+        )
@@ -29,7 +29,9 @@ class BenchmarkDataConfig(BaseModel):

     cache_dir: str = "cache"
     db_path: str = os.path.join(BENCHMARK_DATA_ROOT_PATH, "ant_icube_dev.db")
-    table_mapping_file: str = os.path.join(BENCHMARK_DATA_ROOT_PATH, "table_mapping.json")
+    table_mapping_file: str = os.path.join(
+        BENCHMARK_DATA_ROOT_PATH, "table_mapping.json"
+    )
     cache_expiry_days: int = 1
     repo_url: str = "https://github.com/inclusionAI/Falcon"
     data_dir: str = "data/source"
Binary file not shown.