fix(benchmark): parse multi standard answer

alan.cl
2025-10-15 10:23:27 +08:00
parent 5df8d94f43
commit 92243cb6bc
10 changed files with 273 additions and 161 deletions

View File

@@ -1,10 +1,10 @@
import asyncio
import json
import logging
import asyncio
from functools import cache
from typing import List, Optional
from fastapi import APIRouter, Depends, Query, HTTPException, BackgroundTasks
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
@@ -36,6 +36,7 @@ router = APIRouter()
global_system_app: Optional[SystemApp] = None
logger = logging.getLogger(__name__)
def _run_benchmark_task_sync(
service: BenchmarkService,
evaluate_code: str,
@@ -61,11 +62,15 @@ def _run_benchmark_task_sync(
model_list,
)
)
logger.info(f"Benchmark task run sync finish, evaluate_code: {evaluate_code}")
logger.info(
f"Benchmark task run sync finish, evaluate_code: {evaluate_code}"
)
finally:
loop.close()
except Exception as e:
logger.error(f"Benchmark task failed for evaluate_code: {evaluate_code}, error: {str(e)}")
logger.error(
f"Benchmark task failed for evaluate_code: {evaluate_code}, error: {str(e)}"
)
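
As an aside, a minimal runnable sketch of the pattern this helper uses: a worker thread creates its own event loop to drive an async benchmark coroutine to completion, so the server's running loop is never touched. The coroutine, pool, and evaluate code below are illustrative stand-ins, not the service's real API.

import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor

logger = logging.getLogger(__name__)


async def _fake_benchmark(evaluate_code: str) -> str:
    # Stand-in for the real async benchmark run
    await asyncio.sleep(0.1)
    return f"done: {evaluate_code}"


def run_benchmark_sync(evaluate_code: str) -> None:
    # Each worker thread owns a private event loop
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(_fake_benchmark(evaluate_code))
        logger.info("Benchmark task run sync finish: %s", result)
    except Exception as e:
        logger.error("Benchmark task failed for %s: %s", evaluate_code, e)
    finally:
        loop.close()


with ThreadPoolExecutor(max_workers=2) as pool:
    pool.submit(run_benchmark_sync, "demo-code-001")
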
def get_service() -> Service:
@@ -291,10 +296,9 @@ async def execute_benchmark_task(
)
# Return a success response immediately
return Result.succ({
"evaluate_code": request.evaluate_code,
"status": Status.RUNNING.value
})
return Result.succ(
{"evaluate_code": request.evaluate_code, "status": Status.RUNNING.value}
)
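
For context, a minimal sketch of the fire-and-forget shape of this endpoint: schedule the long-running work and answer immediately with a RUNNING status. It assumes a plain FastAPI app; the route path, parameter, and task body are illustrative only.

import time

from fastapi import BackgroundTasks, FastAPI

app = FastAPI()


def _run_benchmark(evaluate_code: str) -> None:
    time.sleep(1)  # stand-in for the long-running benchmark


@app.post("/benchmark/execute")
async def execute_benchmark(evaluate_code: str, background_tasks: BackgroundTasks):
    # Schedule the heavy work, then reply right away so the client is not blocked
    background_tasks.add_task(_run_benchmark, evaluate_code)
    return {"evaluate_code": evaluate_code, "status": "RUNNING"}
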
@router.get("/benchmark_task_list", dependencies=[Depends(check_api_key)])
@@ -305,7 +309,7 @@ async def benchmark_task_list(
service: BenchmarkService = Depends(get_benchmark_service),
) -> Result:
"""
Query benchmark task list
Query benchmark task list
"""
return Result.succ(
service.get_list_by_page(
@@ -344,47 +348,51 @@ async def get_benchmark_table_rows(table: str, limit: int = 10):
@router.get("/benchmark_result_download", dependencies=[Depends(check_api_key)])
async def download_benchmark_result(
evaluate_code: Optional[str] = Query(default=None, description="evaluate code"),
service: BenchmarkService = Depends(get_benchmark_service),
evaluate_code: Optional[str] = Query(default=None, description="evaluate code"),
service: BenchmarkService = Depends(get_benchmark_service),
):
"""Download benchmark result file
Args:
evaluate_code: The evaluation code to identify the benchmark result
service: The benchmark service instance
Returns:
StreamingResponse: File download response
Raises:
HTTPException: If evaluation code is missing or file not found
"""
logger.info(f"download benchmark result: {evaluate_code}")
if not evaluate_code:
raise HTTPException(status_code=400, detail="evaluate_code is required")
try:
# Get the file name and the file stream
file_name, file_stream = await service.get_benchmark_file_stream(evaluate_code)
from urllib.parse import quote
# URL-encode the file name to support Chinese and special characters
encoded_filename = quote(file_name)
# Return the file download response
return StreamingResponse(
content=file_stream,
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
headers={
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}; filename={encoded_filename}",
"Content-Type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
}
"Content-Disposition": f"attachment; filename*=UTF-8''"
f"{encoded_filename}; filename={encoded_filename}",
"Content-Type": "application/vnd.openxmlformats-"
"officedocument.spreadsheetml.sheet",
},
)
except Exception as e:
logger.error(f"Failed to download benchmark result for {evaluate_code}: {str(e)}")
logger.error(
f"Failed to download benchmark result for {evaluate_code}: {str(e)}"
)
raise HTTPException(status_code=404, detail=str(e))
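
A small self-contained sketch of the download response built above: the file name is percent-encoded and sent both through the RFC 5987 "filename*" parameter and as a plain "filename" fallback, which keeps Chinese and other non-ASCII names intact. The helper name and payload are illustrative.

import io
from urllib.parse import quote

from fastapi.responses import StreamingResponse

XLSX_MIME = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"


def xlsx_download_response(file_name: str, payload: bytes) -> StreamingResponse:
    encoded_filename = quote(file_name)  # e.g. "基准结果.xlsx" -> "%E5%9F%BA..."
    return StreamingResponse(
        content=io.BytesIO(payload),
        media_type=XLSX_MIME,
        headers={
            "Content-Disposition": (
                f"attachment; filename*=UTF-8''{encoded_filename}; "
                f"filename={encoded_filename}"
            )
        },
    )
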
@@ -393,4 +401,4 @@ def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None:
global global_system_app
system_app.register(Service, config=config)
system_app.register(BenchmarkService, config=config)
global_system_app = system_app
global_system_app = system_app

View File

@@ -91,4 +91,4 @@ class BenchmarkServeRequest(BaseModel):
class StorageType(Enum):
FILE = "FILE"
OSS = "OSS"
YU_QUE = "YU_QUE"
YU_QUE = "YU_QUE"

View File

@@ -3,7 +3,9 @@ from typing import Optional, Union
from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
from dbgpt_serve.evaluate.service.benchmark.models import ReasoningResponse
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import get_benchmark_manager
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
get_benchmark_manager,
)
logger = logging.getLogger(__name__)
@@ -30,7 +32,7 @@ class BenchmarkLLMTask:
) -> Union[ReasoningResponse, None]:
"""
Invoke the LLM.
Args:
prompt (Optional[str]): The prompt to use for the LLM.
**kwargs: Keyword arguments for variable replacement in prompt template.
@@ -46,9 +48,9 @@ class BenchmarkLLMTask:
template = HumanPromptTemplate.from_template(
template=prompt, template_is_strict=False
)
if self.dialect and 'dialect' not in kwargs:
kwargs['dialect'] = self.dialect
if self.dialect and "dialect" not in kwargs:
kwargs["dialect"] = self.dialect
messages = template.format_messages(**kwargs)
# use default model if needed
@@ -70,9 +72,11 @@ class BenchmarkLLMTask:
return None
if response.has_text:
return ReasoningResponse(cot_tokens=response.usage.get("total_tokens", 0),
think=response.thinking_text if response.has_thinking else None,
content=self._get_answer(response.text))
return ReasoningResponse(
cot_tokens=response.usage.get("total_tokens", 0),
think=response.thinking_text if response.has_thinking else None,
content=self._get_answer(response.text),
)
else:
return None
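
A tiny stand-alone illustration of the default-variable injection above, using plain str.format instead of the dbgpt prompt template so the snippet runs on its own; the template text and dialect value are made up.

def render_prompt(template: str, dialect: str = "hive", **kwargs) -> str:
    # Fall back to the task-level dialect only when the caller did not pass one
    if dialect and "dialect" not in kwargs:
        kwargs["dialect"] = dialect
    return template.format(**kwargs)


print(render_prompt("Write a {dialect} SQL query for: {question}", question="total GMV in 2024"))
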

View File

@@ -1,3 +1,4 @@
import io
import json
import logging
import os
@@ -7,8 +8,7 @@ import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Union, Any, Tuple
import io
from typing import Dict, List, Optional, Tuple, Union
from dbgpt.agent.core.schema import Status
from dbgpt.component import ComponentType, SystemApp
@@ -17,18 +17,23 @@ from dbgpt.core import LLMClient
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.storage.metadata import BaseDao
from dbgpt.util import get_or_create_event_loop, PaginationResult
from dbgpt.util import PaginationResult, get_or_create_event_loop
from ....core import BaseService
from ....prompt.service.service import Service as PromptService
from ....rag.service.service import Service as RagService
from ....rag.storage_manager import StorageManager
from ...api.schemas import EvaluateServeRequest, EvaluateServeResponse, StorageType, EvaluationScene
from ...api.schemas import (
EvaluateServeRequest,
EvaluateServeResponse,
EvaluationScene,
StorageType,
)
from ...config import ServeConfig
from ...models.models import ServeDao, ServeEntity
from .benchmark_llm_task import BenchmarkLLMTask
from .data_compare_service import DataCompareService
from .file_parse_service import ExcelFileParseService, FileParseService
from .file_parse_service import ExcelFileParseService
from .models import (
BaseInputModel,
BenchmarkDataSets,
@@ -49,15 +54,13 @@ executor = ThreadPoolExecutor(max_workers=5)
BENCHMARK_SERVICE_COMPONENT_NAME = "dbgpt_serve_evaluate_benchmark_service"
# TODO: replace with the official file
STANDARD_BENCHMARK_FILE_PATH = os.path.join(
BENCHMARK_DATA_ROOT_PATH,
"2025_07_27_public_500_standard_benchmark_question_list_v2_local_test.xlsx"
"2025_07_27_public_500_standard_benchmark_question_list_multi_anwser.xlsx",
)
BENCHMARK_OUTPUT_RESULT_PATH = os.path.join(
BENCHMARK_DATA_ROOT_PATH,
"result"
)
BENCHMARK_OUTPUT_RESULT_PATH = os.path.join(BENCHMARK_DATA_ROOT_PATH, "result")
def get_rag_service(system_app) -> RagService:
@@ -94,8 +97,9 @@ class BenchmarkService(
max_workers=5, thread_name_prefix="benchmark-fileWrite"
)
self.output_base_file_name = f"{datetime.now().strftime('%Y%m%d')}_multi_round_benchmark_result.xlsx"
self.output_base_file_name = (
f"{datetime.now().strftime('%Y%m%d%H%M')}_multi_round_benchmark_result.xlsx"
)
def init_app(self, system_app: SystemApp) -> None:
"""Initialize the service
@@ -127,7 +131,6 @@ class BenchmarkService(
).create()
return DefaultLLMClient(worker_manager, True)
def create_benchmark_task(
self,
config: BenchmarkExecuteConfig,
@@ -135,7 +138,7 @@ class BenchmarkService(
scene_key: str,
scene_value: str,
input_file_path: str,
output_file_path: str
output_file_path: str,
) -> bool:
"""
Save the benchmark task to the database
@@ -155,20 +158,24 @@ class BenchmarkService(
evaluate_code=evaluate_code,
scene_key=scene_key,
scene_value=scene_value,
datasets_name=os.path.basename(input_file_path) if input_file_path else None,
datasets_name=os.path.basename(input_file_path)
if input_file_path
else None,
datasets=None,
storage_type=StorageType.FILE.value,
parallel_num=1,
state=Status.RUNNING.value,
result=output_file_path,
context={
"benchmark_config": json.dumps(config.to_dict(), ensure_ascii=False),
"benchmark_config": json.dumps(
config.to_dict(), ensure_ascii=False
),
},
user_id=None,
user_name=None,
sys_code="benchmark_system",
)
response = self.create(request_data)
logger.info(
f"Successfully saved benchmark task to database: "
@@ -184,7 +191,9 @@ class BenchmarkService(
)
return False
def _generate_output_file_full_path(self, output_file_path: str, evaluate_code: str) -> str:
def _generate_output_file_full_path(
self, output_file_path: str, evaluate_code: str
) -> str:
"""
Generate the complete output file path,
including the evaluate_code subfolder and default filename
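
A minimal sketch of how such an output path could be assembled, assuming the per-evaluate_code subfolder and the timestamped default file name seen elsewhere in this commit; the exact naming rules may differ.

import os
from datetime import datetime


def build_output_path(base_dir: str, evaluate_code: str) -> str:
    default_name = (
        f"{datetime.now().strftime('%Y%m%d%H%M')}_multi_round_benchmark_result.xlsx"
    )
    target_dir = os.path.join(base_dir, evaluate_code)
    os.makedirs(target_dir, exist_ok=True)  # ensure the per-task subfolder exists
    return os.path.join(target_dir, default_name)


print(build_output_path("/tmp/benchmark_result", "eval-123"))
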
@@ -229,15 +238,22 @@ class BenchmarkService(
if not scene_key:
scene_key = EvaluationScene.DATASET.value
output_file_path = (
self._generate_output_file_full_path(output_file_path, evaluate_code)
output_file_path = self._generate_output_file_full_path(
output_file_path, evaluate_code
)
config = await self._build_benchmark_config(model_list, output_file_path,
evaluate_code, scene_key)
config = await self._build_benchmark_config(
model_list, output_file_path, evaluate_code, scene_key
)
# save benchmark task
self.create_benchmark_task(config, evaluate_code, scene_key, scene_value,
input_file_path, output_file_path)
self.create_benchmark_task(
config,
evaluate_code,
scene_key,
scene_value,
input_file_path,
output_file_path,
)
# read input file
input_list: List[BaseInputModel] = (
@@ -281,12 +297,15 @@ class BenchmarkService(
)
result_list.extend(round_result_list)
logger.info(f"Benchmark task completed successfully for evaluate_code:"
f" {evaluate_code}, output_file_path: {output_file_path}")
logger.info(
f"Benchmark task completed successfully for evaluate_code:"
f" {evaluate_code}, output_file_path: {output_file_path}"
)
return result_list
async def _build_benchmark_config(self, model_list, output_file_path,
evaluate_code, scene_key) -> BenchmarkExecuteConfig:
async def _build_benchmark_config(
self, model_list, output_file_path, evaluate_code, scene_key
) -> BenchmarkExecuteConfig:
config = BenchmarkExecuteConfig(
benchmark_mode_type=BenchmarkModeTypeEnum.EXECUTE,
standard_file_path=STANDARD_BENCHMARK_FILE_PATH,
@@ -320,7 +339,9 @@ class BenchmarkService(
"""
result = BenchmarkTaskResult[OutputType]()
result.trace_id = uuid.uuid4().hex
result.task_id = config.evaluate_code if config.evaluate_code else uuid.uuid4().hex
result.task_id = (
config.evaluate_code if config.evaluate_code else uuid.uuid4().hex
)
result.start_time = datetime.now()
executor = ThreadPoolExecutor(
@@ -389,7 +410,7 @@ class BenchmarkService(
return result
async def execute(
self, config: BenchmarkExecuteConfig, input: InputType
self, config: BenchmarkExecuteConfig, input: InputType
) -> Union[OutputType, None]:
"""
Execute LLM Benchmark Task
@@ -405,16 +426,16 @@ class BenchmarkService(
llm_client=self.llm_client, model_name=input.llm_code
)
response: ReasoningResponse = await (
benchmark_llm_task_service.invoke_llm(prompt=input.prompt)
response: ReasoningResponse = await benchmark_llm_task_service.invoke_llm(
prompt=input.prompt
)
# 3. Assemble the evaluation output
return await self.user_input_execute_service.build_output(config, input, response)
except Exception as e:
logger.error(
f"execute benchmark error! error: {e}"
return await self.user_input_execute_service.build_output(
config, input, response
)
except Exception as e:
logger.error(f"execute benchmark error! error: {e}")
return None
def check_and_trigger_batch(
@@ -485,13 +506,18 @@ class BenchmarkService(
batch_write_task()
written_batches.add(batch_index)
def post_dispatch(self, i: int, config: BenchmarkExecuteConfig,
input_list: List[BaseInputModel],
output_list: List[BenchmarkTaskResult[OutputType]],
input_file_path: str, output_file_path: str):
def post_dispatch(
self,
i: int,
config: BenchmarkExecuteConfig,
input_list: List[BaseInputModel],
output_list: List[BenchmarkTaskResult[OutputType]],
input_file_path: str,
output_file_path: str,
):
"""
Post-dispatch processing: compare the LLM execution results with the standard results
and write the comparison results to file
Post-dispatch processing: compare the LLM execution results with the standard results
and write the comparison results to file
"""
for j, output_result in enumerate(output_list):
self.user_input_execute_service.post_dispatch(
@@ -518,9 +544,13 @@ class BenchmarkService(
List[EvaluateServeResponse]: The response
"""
query_request = request
return self.dao.get_list_page(query_request, page, page_size, ServeEntity.id.name)
return self.dao.get_list_page(
query_request, page, page_size, ServeEntity.id.name
)
async def get_benchmark_file_stream(self, evaluate_code: str) -> Tuple[str, io.BytesIO]:
async def get_benchmark_file_stream(
self, evaluate_code: str
) -> Tuple[str, io.BytesIO]:
"""Get benchmark result file stream for download
Args:
@@ -539,7 +569,9 @@ class BenchmarkService(
try:
entity = self.dao.get_one({"evaluate_code": evaluate_code})
if not entity:
raise Exception(f"Evaluation record not found for code: {evaluate_code}")
raise Exception(
f"Evaluation record not found for code: {evaluate_code}"
)
except Exception as e:
logger.error(f"Failed to query evaluation record: {e}")
raise Exception(f"Failed to query evaluation record: {str(e)}")
@@ -547,7 +579,9 @@ class BenchmarkService(
# 2. Get the file from the path stored in result
file_path = entity.result
if not file_path:
raise Exception(f"No result file path found for evaluate_code: {evaluate_code}")
raise Exception(
f"No result file path found for evaluate_code: {evaluate_code}"
)
# Check whether the file exists
if not os.path.exists(file_path):
@@ -555,18 +589,18 @@ class BenchmarkService(
try:
# Read the file content into memory
with open(file_path, 'rb') as file:
with open(file_path, "rb") as file:
file_content = file.read()
# Create a byte stream
file_stream = io.BytesIO(file_content)
# Get the file name
file_name = os.path.basename(file_path)
logger.info(f"Successfully prepared file stream for download: {file_name}")
return file_name, file_stream
except Exception as e:
logger.error(f"Failed to read result file {file_path}: {e}")
raise Exception(f"Failed to read result file: {str(e)}")
raise Exception(f"Failed to read result file: {str(e)}")

View File

@@ -74,6 +74,7 @@ class DataCompareService:
if not cfg.standard_result:
return DataCompareResult.failed("leftResult is null")
# Compare against every standard answer; if any one standard answer is contained, the result is considered correct, otherwise wrong
for std in cfg.standard_result:
if not isinstance(std, dict):
continue
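
A minimal stand-alone sketch of the "right if any standard answer matches" rule, assuming column-oriented answer dicts; the real CONTAIN_MATCH strategy carries more rules, so this only shows the gist.

from typing import Dict, List


def is_right(actual: Dict[str, List[str]], standards: List[Dict[str, List[str]]]) -> bool:
    for std in standards:
        if not isinstance(std, dict):
            continue
        # Every column of this standard answer must be contained in the actual result
        if all(set(values).issubset(set(actual.get(col, []))) for col, values in std.items()):
            return True
    return False


standards = [{"city": ["Beijing"]}, {"city": ["北京"]}]
print(is_right({"city": ["北京", "上海"]}, standards))  # True: the second answer matches
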

View File

@@ -3,11 +3,11 @@ import json
import logging
import os
from abc import ABC, abstractmethod
from typing import List, Any, Dict
from pathlib import Path
from typing import Any, Dict, List, Optional
import pandas as pd
from pathlib import Path
from openpyxl import load_workbook, Workbook
from openpyxl import Workbook, load_workbook
from dbgpt.util.benchmarks.ExcelUtils import ExcelUtils
from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao
@@ -16,15 +16,16 @@ from .models import (
AnswerExecuteModel,
BaseInputModel,
BenchmarkDataSets,
BenchmarkExecuteConfig,
DataCompareStrategyConfig,
RoundAnswerConfirmModel, BenchmarkExecuteConfig, OutputType,
OutputType,
RoundAnswerConfirmModel,
)
logger = logging.getLogger(__name__)
class FileParseService(ABC):
def __init__(self):
self._benchmark_dao = BenchmarkResultDao()
@@ -32,7 +33,7 @@ class FileParseService(ABC):
self._column_config_file_path = os.path.join(
os.path.dirname(__file__),
"template",
"benchmark_column_config_template.json"
"benchmark_column_config_template.json",
)
@abstractmethod
@@ -174,17 +175,35 @@ class FileParseService(ABC):
extension = Path(output_path).suffix
if extension.lower() not in [".xlsx", ".xls"]:
extension = ".xlsx"
excel_file = Path(output_path).parent / f"{base_name}_round{round_id}{extension}"
excel_file = (
Path(output_path).parent / f"{base_name}_round{round_id}{extension}"
)
if not excel_file.exists():
logger.warning(f"summary excel not found: {excel_file}")
result = dict(right=0, wrong=0, failed=0, exception=0)
return json.dumps(result, ensure_ascii=False)
df = pd.read_excel(str(excel_file), sheet_name="benchmark_compare_result")
right = int((df["compareResult"] == "RIGHT").sum()) if "compareResult" in df.columns else 0
wrong = int((df["compareResult"] == "WRONG").sum()) if "compareResult" in df.columns else 0
failed = int((df["compareResult"] == "FAILED").sum()) if "compareResult" in df.columns else 0
exception = int((df["compareResult"] == "EXCEPTION").sum()) if "compareResult" in df.columns else 0
right = (
int((df["compareResult"] == "RIGHT").sum())
if "compareResult" in df.columns
else 0
)
wrong = (
int((df["compareResult"] == "WRONG").sum())
if "compareResult" in df.columns
else 0
)
failed = (
int((df["compareResult"] == "FAILED").sum())
if "compareResult" in df.columns
else 0
)
exception = (
int((df["compareResult"] == "EXCEPTION").sum())
if "compareResult" in df.columns
else 0
)
result = dict(right=right, wrong=wrong, failed=failed, exception=exception)
logger.info(
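
As a possible simplification of the repeated tallying above (not what this commit does), pandas value_counts can produce all four counts in one pass; the DataFrame here is illustrative.

import pandas as pd

df = pd.DataFrame({"compareResult": ["RIGHT", "WRONG", "RIGHT", "EXCEPTION"]})
counts = (
    df["compareResult"].value_counts()
    if "compareResult" in df.columns
    else pd.Series(dtype=int)
)
result = {key.lower(): int(counts.get(key, 0)) for key in ("RIGHT", "WRONG", "FAILED", "EXCEPTION")}
print(result)  # {'right': 2, 'wrong': 1, 'failed': 0, 'exception': 1}
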
@@ -223,10 +242,10 @@ class FileParseService(ABC):
"""
Parse standard benchmark sets from file.
This method must be implemented by subclasses.
Args:
standard_excel_path: Path to the standard benchmark file
Returns:
List[AnswerExecuteModel]: List of parsed answer execute models
"""
@@ -234,14 +253,14 @@ class FileParseService(ABC):
@abstractmethod
def write_multi_round_benchmark_result(
self,
output_file_path: str,
round_id: int,
config: BenchmarkExecuteConfig,
inputs: List[BaseInputModel],
outputs: List[OutputType],
start_index: int,
offset: int,
self,
output_file_path: str,
round_id: int,
config: BenchmarkExecuteConfig,
inputs: List[BaseInputModel],
outputs: List[OutputType],
start_index: int,
offset: int,
) -> bool:
"""
Write the multi-round benchmark task result
@@ -257,9 +276,7 @@ class FileParseService(ABC):
"""
class ExcelFileParseService(FileParseService):
def parse_input_sets(self, path: str) -> BenchmarkDataSets:
"""
Parse input sets from excel file
@@ -327,9 +344,7 @@ class ExcelFileParseService(FileParseService):
if workbook is not None:
workbook.close()
except Exception as e:
logger.error(
f"close workbook error, path: {path}, errorMsg: {e}"
)
logger.error(f"close workbook error, path: {path}, errorMsg: {e}")
return input_sets
@@ -355,17 +370,15 @@ class ExcelFileParseService(FileParseService):
except Exception:
order_by = True
std_result = None
std_result: Optional[List[Dict[str, List[str]]]] = None
if not pd.isna(row.get("标准结果")):
try:
std_result = json.loads(row.get("标准结果"))
except Exception:
std_result = None
std_result_raw = str(row.get("标准结果"))
std_result = self._parse_multi_standard_result(std_result_raw)
strategy_config = DataCompareStrategyConfig(
strategy="CONTAIN_MATCH",
order_by=order_by,
standard_result=[std_result]
standard_result=std_result
if std_result is not None
else None,  # use a list
)
@@ -382,14 +395,14 @@ class ExcelFileParseService(FileParseService):
return outputs
def write_multi_round_benchmark_result(
self,
output_file_path: str,
round_id: int,
config: BenchmarkExecuteConfig,
inputs: List[BaseInputModel],
outputs: List[OutputType],
start_index: int,
offset: int,
self,
output_file_path: str,
round_id: int,
config: BenchmarkExecuteConfig,
inputs: List[BaseInputModel],
outputs: List[OutputType],
start_index: int,
offset: int,
) -> bool:
"""
Write the benchmark result to an Excel file across multiple rounds
@@ -472,8 +485,8 @@ class ExcelFileParseService(FileParseService):
for col_idx, header in enumerate(headers, 1):
worksheet.cell(row=1, column=col_idx, value=header)
# Compute the starting row for writing
# Formula: start_index + offset + 2 (+1 because Excel rows start at 1, +1 because the header occupies one row)
# Compute the starting row for writing: start_index + offset + 2
# (+1 because Excel rows start at 1, +1 because the header occupies one row)
write_start_row = start_index + offset + 2
# Write data rows
@@ -492,9 +505,7 @@ class ExcelFileParseService(FileParseService):
if cell.value and len(str(cell.value)) > max_length:
max_length = len(str(cell.value))
except Exception as e:
logger.warning(
f"error while compute column length: {str(e)}"
)
logger.warning(f"error while compute column length: {str(e)}")
# Set column width: minimum 10, maximum 50
adjusted_width = min(max(max_length + 2, 10), 50)
worksheet.column_dimensions[column_letter].width = adjusted_width
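
A small stand-alone sketch of the auto-width rule applied above (max cell length plus 2, clamped between 10 and 50), using openpyxl; the sheet content and output file name are illustrative.

from openpyxl import Workbook
from openpyxl.utils import get_column_letter

wb = Workbook()
ws = wb.active
ws.append(["question", "compareResult"])
ws.append(["What is the total GMV of 2024?", "RIGHT"])

for col_idx, column_cells in enumerate(ws.columns, 1):
    # Longest rendered value in this column
    max_length = max(len(str(cell.value)) for cell in column_cells if cell.value is not None)
    ws.column_dimensions[get_column_letter(col_idx)].width = min(max(max_length + 2, 10), 50)

wb.save("benchmark_width_demo.xlsx")
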
@@ -516,13 +527,13 @@ class ExcelFileParseService(FileParseService):
return False
def _get_value_by_source_type(
self,
field: str,
source_type: str,
processor_type: str,
input_data,
output,
round_id: int
self,
field: str,
source_type: str,
processor_type: str,
input_data,
output,
round_id: int,
) -> Any:
"""
Get the value based on the source type
@@ -576,11 +587,7 @@ class ExcelFileParseService(FileParseService):
else ""
)
else:
value = (
str(output.executeResult)
if output.executeResult
else ""
)
value = str(output.executeResult) if output.executeResult else ""
elif field == "errorMsg":
value = output.errorMsg
elif field == "traceId":
@@ -616,6 +623,46 @@ class ExcelFileParseService(FileParseService):
else:
return str(value) if value is not None else ""
def _parse_multi_standard_result(
self, std_result_raw: str
) -> Optional[List[Dict[str, List[str]]]]:
"""
Parse multiple standard results from raw string data.
Handles multiple results separated by newlines and parses each line as a dict.
Args:
std_result_raw (str): Raw standard result string with multiple lines
Returns:
Optional[List[Dict[str, List[str]]]]: List of parsed dictionaries,
or None if parsing fails or no valid data
"""
try:
std_result_raw = std_result_raw.strip()
if not std_result_raw:
return None
# Handle multiple results separated by newlines
result_lines = std_result_raw.split("\n")
result_list = []
for line in result_lines:
line = line.strip()
if line:
try:
result_list.append(json.loads(line))
except Exception as e:
logger.warning(
f"Failed to parse line as JSON: {line}, error: {e}"
)
continue
return result_list if result_list else None
except Exception as e:
logger.error(f"parse multiple standard results error: {e}")
return None
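
An illustrative round-trip for the new newline-separated parsing, assuming each line of the standard-result cell is an independent JSON object; the cell content below is made up.

import json
from typing import Dict, List, Optional


def parse_multi_standard_result(raw: str) -> Optional[List[Dict[str, List[str]]]]:
    parsed = []
    for line in raw.strip().split("\n"):
        line = line.strip()
        if not line:
            continue
        try:
            parsed.append(json.loads(line))
        except json.JSONDecodeError:
            continue  # skip lines that are not valid JSON
    return parsed or None


cell = '{"gmv": ["1000"]}\n{"gmv": ["1,000.00"]}'
print(parse_multi_standard_result(cell))
# [{'gmv': ['1000']}, {'gmv': ['1,000.00']}]
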
def _load_column_config(self) -> List[Dict]:
"""
Load column configuration from JSON file
@@ -624,10 +671,12 @@ class ExcelFileParseService(FileParseService):
List[Dict]: List of column configurations
"""
try:
with open(self._column_config_file_path, 'r', encoding='utf-8') as file:
with open(self._column_config_file_path, "r", encoding="utf-8") as file:
config_data = json.load(file)
return config_data.get("columns", [])
except Exception as e:
logger.error(f"Failed to load column configuration file: {e},"
f" using default configuration")
raise ValueError("Failed to load column configuration file")
logger.error(
f"Failed to load column configuration file: {e},"
f" using default configuration"
)
raise ValueError("Failed to load column configuration file")

View File

@@ -18,7 +18,12 @@ class BenchmarkModeTypeEnum(str, Enum):
class DataCompareStrategyConfig:
strategy: str # "EXACT_MATCH" | "CONTAIN_MATCH"
order_by: bool = True
standard_result: Optional[List[Dict[str, List[str]]]] = None  # changed to list[dict]
"""
Standard answer, each dict in the list represents a reference answer
containing multiple columns of data. If any reference answer is matched,
the result is considered correct
"""
standard_result: Optional[List[Dict[str, List[str]]]] = None
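
To illustrate the shape of standard_result, a minimal stand-in dataclass mirroring the fields above; the values are made up, and matching either reference answer would count as correct.

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class StrategyConfigSketch:
    strategy: str
    order_by: bool = True
    standard_result: Optional[List[Dict[str, List[str]]]] = None


cfg = StrategyConfigSketch(
    strategy="CONTAIN_MATCH",
    order_by=False,
    standard_result=[
        {"gmv": ["1000"]},      # first reference answer
        {"gmv": ["1,000.00"]},  # alternative formatting of the same value
    ],
)
print(cfg.standard_result)
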
class DataCompareResultEnum(str, Enum):

View File

@@ -1,6 +1,6 @@
# app/services/user_input_execute_service.py
import logging
from typing import Dict, List, Union, Optional
from typing import Dict, List, Optional, Union
from dbgpt.util.benchmarks import StorageUtil
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
@@ -12,13 +12,16 @@ from .file_parse_service import FileParseService
from .models import (
AnswerExecuteModel,
BaseInputModel,
BenchmarkDataSets,
BenchmarkExecuteConfig,
BenchmarkModeTypeEnum,
DataCompareResultEnum,
DataCompareStrategyConfig,
FileParseTypeEnum,
InputType,
OutputType,
ReasoningResponse,
RoundAnswerConfirmModel, OutputType, FileParseTypeEnum, BenchmarkDataSets,
RoundAnswerConfirmModel,
)
BENCHMARK_DEFAULT_DB_SCHEMA = "ant_icube_dev."
@@ -27,7 +30,6 @@ logger = logging.getLogger(__name__)
class UserInputExecuteService:
def __init__(
self, file_service: FileParseService, compare_service: DataCompareService
):
@@ -116,7 +118,9 @@ class UserInputExecuteService:
confirm_list: List[RoundAnswerConfirmModel] = []
# compute unique llm_count across all right answers
llm_codes = set([a.llm_code for a in right_answers if getattr(a, "llm_code", None)])
llm_codes = set(
[a.llm_code for a in right_answers if getattr(a, "llm_code", None)]
)
llm_count = len(llm_codes) if llm_codes else len(right_answers)
for inp in inputs:
@@ -133,7 +137,8 @@ class UserInputExecuteService:
standard_sql = left.llmOutput
if config.benchmark_mode_type == BenchmarkModeTypeEnum.EXECUTE:
strategy_cfg = left.strategyConfig
# Prefer the left-side execution result as the standard answer; if absent, try the first item of the strategy config's standard_result
# Prefer the left-side execution result as the standard answer; if absent,
# try the first item of the strategy config's standard_result
if left.executeResult is not None:
standard_answer = left.executeResult
elif left.strategyConfig and left.strategyConfig.standard_result:
@@ -149,7 +154,9 @@ class UserInputExecuteService:
strategy_cfg = DataCompareStrategyConfig(
strategy="EXACT_MATCH",
order_by=True,
standard_result=standard_result_list if standard_result_list else None,
standard_result=standard_result_list
if standard_result_list
else None,
)
# for each right answer (per model)
@@ -166,7 +173,9 @@ class UserInputExecuteService:
compare_result = DataCompareResultEnum.FAILED
else:
res = self.compare_service.compare(
left if left else AnswerExecuteModel(
left
if left
else AnswerExecuteModel(
serialNo=inp.serial_no,
analysisModelId=inp.analysis_model_id,
question=inp.question,
@@ -266,7 +275,7 @@ class UserInputExecuteService:
error_msg = None
if config.execute_llm_result:
logger.info("[benchmark_task] queryResult start!")
logger.info(f"[benchmark_task] queryResult start!, sql:{sql}")
try:
result: List[Dict] = await get_benchmark_manager().query(sql)
execute_result = self._convert_query_result_to_column_format(result)
@@ -275,7 +284,7 @@ class UserInputExecuteService:
f"[benchmark_task] queryResult error! sql = {sql}, errorMsg: {e}"
)
error_msg = str(e)
logger.info(f"[benchmark_task] queryResult end! result = {execute_result}")
logger.info(f"[benchmark_task] queryResult end!")
return AnswerExecuteModel(
serialNo=input.serial_no,
@@ -340,11 +349,11 @@ class UserInputExecuteService:
def _process_sql_db_schema(self, sql: str) -> str:
"""
Process SQL: remove the database schema prefix for compatibility with SQLite syntax
Process SQL: remove the database schema prefix for compatibility with SQLite syntax
"""
if not sql or not isinstance(sql, str):
return sql
# only replace the "ant_icube_dev." prefix
return sql.replace(BENCHMARK_DEFAULT_DB_SCHEMA, "")
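
A minimal runnable sketch of the schema-prefix stripping, assuming only the "ant_icube_dev." prefix needs removal for SQLite compatibility, which is what the method above does.

BENCHMARK_DEFAULT_DB_SCHEMA = "ant_icube_dev."


def strip_db_schema(sql: str) -> str:
    if not sql or not isinstance(sql, str):
        return sql
    # Only the default schema prefix is replaced
    return sql.replace(BENCHMARK_DEFAULT_DB_SCHEMA, "")


print(strip_db_schema("SELECT * FROM ant_icube_dev.sales LIMIT 10"))
# SELECT * FROM sales LIMIT 10
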
@@ -374,5 +383,5 @@ class UserInputExecuteService:
bool: Returns True if write is successful, False otherwise
"""
return self.file_service.write_multi_round_benchmark_result(
output_file_path, round_id, config, inputs, outputs, start_index, offset)
output_file_path, round_id, config, inputs, outputs, start_index, offset
)

View File

@@ -29,7 +29,9 @@ class BenchmarkDataConfig(BaseModel):
cache_dir: str = "cache"
db_path: str = os.path.join(BENCHMARK_DATA_ROOT_PATH, "ant_icube_dev.db")
table_mapping_file: str = os.path.join(BENCHMARK_DATA_ROOT_PATH, "table_mapping.json")
table_mapping_file: str = os.path.join(
BENCHMARK_DATA_ROOT_PATH, "table_mapping.json"
)
cache_expiry_days: int = 1
repo_url: str = "https://github.com/inclusionAI/Falcon"
data_dir: str = "data/source"