feat(benchmark): benchmark result file download

alan.cl
2025-10-14 14:43:00 +08:00
parent 838bc359ad
commit 5df8d94f43
5 changed files with 110 additions and 5 deletions

View File

@@ -5,6 +5,7 @@ from functools import cache
from typing import List, Optional
from fastapi import APIRouter, Depends, Query, HTTPException, BackgroundTasks
+from fastapi.responses import StreamingResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.agent.core.schema import Status
@@ -262,7 +263,7 @@ async def get_compare_run_detail(summary_id: int, limit: int = 200, offset: int
    return Result.succ(detail)

-@router.post("/execute_benchmark_task")
+@router.post("/execute_benchmark_task", dependencies=[Depends(check_api_key)])
async def execute_benchmark_task(
    request: BenchmarkServeRequest,
    background_tasks: BackgroundTasks,
@@ -296,7 +297,7 @@ async def execute_benchmark_task(
    })

-@router.get("/benchmark_task_list")
+@router.get("/benchmark_task_list", dependencies=[Depends(check_api_key)])
async def benchmark_task_list(
    request: EvaluateServeRequest,
    page: Optional[int] = Query(default=1, description="current page"),
@@ -341,6 +342,52 @@ async def get_benchmark_table_rows(table: str, limit: int = 10):
return Result.succ({"table": table, "limit": limit, "rows": rows})
@router.get("/benchmark_result_download", dependencies=[Depends(check_api_key)])
async def download_benchmark_result(
evaluate_code: Optional[str] = Query(default=None, description="evaluate code"),
service: BenchmarkService = Depends(get_benchmark_service),
):
"""Download benchmark result file
Args:
evaluate_code: The evaluation code to identify the benchmark result
service: The benchmark service instance
Returns:
StreamingResponse: File download response
Raises:
HTTPException: If evaluation code is missing or file not found
"""
logger.info(f"download benchmark result: {evaluate_code}")
if not evaluate_code:
raise HTTPException(status_code=400, detail="evaluate_code is required")
try:
# 获取文件名和文件流
file_name, file_stream = await service.get_benchmark_file_stream(evaluate_code)
from urllib.parse import quote
# 对文件名进行编码处理,支持中文和特殊字符
encoded_filename = quote(file_name)
# 返回文件下载响应
return StreamingResponse(
content=file_stream,
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
headers={
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}; filename={encoded_filename}",
"Content-Type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
}
)
except Exception as e:
logger.error(f"Failed to download benchmark result for {evaluate_code}: {str(e)}")
raise HTTPException(status_code=404, detail=str(e))
def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None:
    """Initialize the endpoints"""
    global global_system_app
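For reference, a minimal client-side sketch of how the new download endpoint could be exercised once this commit is deployed. The base URL, API key, and evaluate_code are placeholders, and `requests` is just one convenient HTTP client; only the route path and the bearer-style key check come from the diff above.

import requests

BASE_URL = "http://localhost:8000"  # placeholder; depends on where the router is mounted
API_KEY = "your-api-key"            # placeholder; whatever check_api_key accepts

resp = requests.get(
    f"{BASE_URL}/benchmark_result_download",
    params={"evaluate_code": "some-evaluate-code"},  # placeholder code
    headers={"Authorization": f"Bearer {API_KEY}"},
    timeout=60,
)
resp.raise_for_status()

# The endpoint streams an xlsx workbook; save it to disk.
with open("benchmark_result.xlsx", "wb") as f:
    f.write(resp.content)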

View File

@@ -70,7 +70,9 @@ class BenchmarkLLMTask:
            return None
        if response.has_text:
-            return ReasoningResponse(content=self._get_answer(response.text))
+            return ReasoningResponse(cot_tokens=response.usage.get("total_tokens", 0),
+                                     think=response.thinking_text if response.has_thinking else None,
+                                     content=self._get_answer(response.text))
        else:
            return None
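The new return path populates three fields instead of one. The actual ReasoningResponse is defined elsewhere in the repo and may differ; purely as an illustration of the shape this call site assumes, it would need to accept something like:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ReasoningResponse:
    # Illustrative only; field names are taken from the call site above,
    # the defaults are guesses.
    content: str
    think: Optional[str] = None
    cot_tokens: int = 0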

View File

@@ -7,7 +7,8 @@ import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
-from typing import Dict, List, Optional, Union, Any
+from typing import Dict, List, Optional, Union, Any, Tuple
+import io
from dbgpt.agent.core.schema import Status
from dbgpt.component import ComponentType, SystemApp
@@ -481,7 +482,7 @@ class BenchmarkService(
            except Exception as e:
                logger.error(f"Batch write error: {e}")
-            self.trigger_executor.submit(batch_write_task)
+            batch_write_task()
            written_batches.add(batch_index)

    def post_dispatch(self, i: int, config: BenchmarkExecuteConfig,
@@ -519,3 +520,53 @@ class BenchmarkService(
            query_request = request
        return self.dao.get_list_page(query_request, page, page_size, ServeEntity.id.name)

+    async def get_benchmark_file_stream(self, evaluate_code: str) -> Tuple[str, io.BytesIO]:
+        """Get the benchmark result file stream for download.
+
+        Args:
+            evaluate_code (str): The evaluation code
+
+        Returns:
+            Tuple[str, io.BytesIO]: File name and file stream
+
+        Raises:
+            Exception: If the evaluation record is not found or the file does not exist
+        """
+        if not evaluate_code:
+            raise Exception("evaluate_code is required")
+
+        # 1. Look up the evaluation record by evaluate_code
+        try:
+            entity = self.dao.get_one({"evaluate_code": evaluate_code})
+        except Exception as e:
+            logger.error(f"Failed to query evaluation record: {e}")
+            raise Exception(f"Failed to query evaluation record: {str(e)}")
+        # Raise "not found" outside the try so it is not rewrapped as a query failure
+        if not entity:
+            raise Exception(f"Evaluation record not found for code: {evaluate_code}")
+
+        # 2. Resolve the file from the path stored in the result field
+        file_path = entity.result
+        if not file_path:
+            raise Exception(f"No result file path found for evaluate_code: {evaluate_code}")
+        # Make sure the file actually exists on disk
+        if not os.path.exists(file_path):
+            raise Exception(f"Result file not found: {file_path}")
+
+        try:
+            # Read the file contents into memory
+            with open(file_path, "rb") as file:
+                file_content = file.read()
+            # Wrap the bytes in an in-memory stream
+            file_stream = io.BytesIO(file_content)
+            # Use the base name of the path as the download file name
+            file_name = os.path.basename(file_path)
+            logger.info(f"Successfully prepared file stream for download: {file_name}")
+            return file_name, file_stream
+        except Exception as e:
+            logger.error(f"Failed to read result file {file_path}: {e}")
+            raise Exception(f"Failed to read result file: {str(e)}")

View File

@@ -62,6 +62,8 @@ class AnswerExecuteModel:
    cotTokens: Optional[Any] = None
    cost_time: Optional[int] = None
    llm_code: Optional[str] = None
+    knowledge: Optional[str] = None
+    prompt: Optional[str] = None

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "AnswerExecuteModel":
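The from_dict body is truncated in this diff, so the exact key mapping is not shown; assuming it mirrors the field names, a round-trip with the two new fields would look like:

model = AnswerExecuteModel.from_dict({
    "llm_code": "gpt-4o",                 # placeholder values throughout
    "knowledge": "users(id, name, age)",
    "prompt": "List all user names",
})
print(model.knowledge, model.prompt)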

View File

@@ -274,6 +274,7 @@ class UserInputExecuteService:
            logger.error(
                f"[benchmark_task] queryResult error! sql = {sql}, errorMsg: {e}"
            )
+            error_msg = str(e)
        logger.info(f"[benchmark_task] queryResult end! result = {execute_result}")
        return AnswerExecuteModel(
@@ -285,6 +286,8 @@ class UserInputExecuteService:
            cotTokens=response.cot_tokens,
            errorMsg=error_msg,
            llm_code=input.llm_code,
+            knowledge=input.knowledge,
+            prompt=input.prompt,
        )

    def _extract_sql_content(self, content: str) -> str:
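The diff cuts off at the _extract_sql_content signature, so its body is not shown. As a hypothetical sketch, a helper like this typically pulls SQL out of a fenced code block in the model output and falls back to the raw text:

import re

def extract_sql_content(content: str) -> str:
    # Grab the body of a ```sql ... ``` block if present; otherwise assume
    # the whole response is the SQL statement.
    match = re.search(r"```sql\s*(.*?)```", content, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else content.strip()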