From 5df8d94f430e18fac2d5193d023b8ca52a49be92 Mon Sep 17 00:00:00 2001
From: "alan.cl" <1165243776@qq.com>
Date: Tue, 14 Oct 2025 14:43:00 +0800
Subject: [PATCH] feat(benchmark): benchmark result file download

---
 .../src/dbgpt_serve/evaluate/api/endpoints.py  | 51 ++++++++++++++++-
 .../service/benchmark/benchmark_llm_task.py    |  4 +-
 .../service/benchmark/benchmark_service.py     | 55 ++++++++++++++++++-
 .../evaluate/service/benchmark/models.py       |  2 +
 .../benchmark/user_input_execute_service.py    |  3 +
 5 files changed, 110 insertions(+), 5 deletions(-)

diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/endpoints.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/endpoints.py
index b93d0cc54..c14312000 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/endpoints.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/api/endpoints.py
@@ -5,6 +5,7 @@ from functools import cache
 from typing import List, Optional
 
 from fastapi import APIRouter, Depends, Query, HTTPException, BackgroundTasks
+from fastapi.responses import StreamingResponse
 from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
 
 from dbgpt.agent.core.schema import Status
@@ -262,7 +263,7 @@ async def get_compare_run_detail(summary_id: int, limit: int = 200, offset: int
     return Result.succ(detail)
 
 
-@router.post("/execute_benchmark_task")
+@router.post("/execute_benchmark_task", dependencies=[Depends(check_api_key)])
 async def execute_benchmark_task(
     request: BenchmarkServeRequest,
     background_tasks: BackgroundTasks,
@@ -296,7 +297,7 @@ async def execute_benchmark_task(
     })
 
 
-@router.get("/benchmark_task_list")
+@router.get("/benchmark_task_list", dependencies=[Depends(check_api_key)])
 async def benchmark_task_list(
     request: EvaluateServeRequest,
     page: Optional[int] = Query(default=1, description="current page"),
@@ -341,6 +342,52 @@ async def get_benchmark_table_rows(table: str, limit: int = 10):
     return Result.succ({"table": table, "limit": limit, "rows": rows})
 
 
+@router.get("/benchmark_result_download", dependencies=[Depends(check_api_key)])
+async def download_benchmark_result(
+    evaluate_code: Optional[str] = Query(default=None, description="evaluate code"),
+    service: BenchmarkService = Depends(get_benchmark_service),
+):
+    """Download benchmark result file
+
+    Args:
+        evaluate_code: The evaluation code to identify the benchmark result
+        service: The benchmark service instance
+
+    Returns:
+        StreamingResponse: File download response
+
+    Raises:
+        HTTPException: If the evaluation code is missing or the file is not found
+    """
+    logger.info(f"download benchmark result: {evaluate_code}")
+
+    if not evaluate_code:
+        raise HTTPException(status_code=400, detail="evaluate_code is required")
+
+    try:
+        # Get the file name and file stream
+        file_name, file_stream = await service.get_benchmark_file_stream(evaluate_code)
+
+        from urllib.parse import quote
+
+        # URL-encode the file name to support Chinese and special characters
+        encoded_filename = quote(file_name)
+
+        # Return the file download response
+        return StreamingResponse(
+            content=file_stream,
+            media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+            headers={
+                "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}; filename={encoded_filename}",
+                "Content-Type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
+            }
+        )
+
+    except Exception as e:
+        logger.error(f"Failed to download benchmark result for {evaluate_code}: {str(e)}")
+        raise HTTPException(status_code=404, detail=str(e))
+
+
 def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None:
     """Initialize the endpoints"""
     global global_system_app
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/benchmark_llm_task.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/benchmark_llm_task.py
index e26eecbe5..dffa6d3dd 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/benchmark_llm_task.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/benchmark_llm_task.py
@@ -70,7 +70,9 @@ class BenchmarkLLMTask:
             return None
 
         if response.has_text:
-            return ReasoningResponse(content=self._get_answer(response.text))
+            return ReasoningResponse(cot_tokens=response.usage.get("total_tokens", 0),
+                                     think=response.thinking_text if response.has_thinking else None,
+                                     content=self._get_answer(response.text))
         else:
             return None
 
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/benchmark_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/benchmark_service.py
index 4538fd8a8..42c5fcdc8 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/benchmark_service.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/benchmark_service.py
@@ -7,7 +7,8 @@ import uuid
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, List, Optional, Union, Any
+from typing import Dict, List, Optional, Union, Any, Tuple
+import io
 
 from dbgpt.agent.core.schema import Status
 from dbgpt.component import ComponentType, SystemApp
@@ -481,7 +482,7 @@ class BenchmarkService(
             except Exception as e:
                 logger.error(f"Batch write error: {e}")
 
-        self.trigger_executor.submit(batch_write_task)
+        batch_write_task()
         written_batches.add(batch_index)
 
     def post_dispatch(self, i: int, config: BenchmarkExecuteConfig,
@@ -519,3 +520,53 @@ class BenchmarkService(
         query_request = request
         return self.dao.get_list_page(query_request, page, page_size, ServeEntity.id.name)
 
+    async def get_benchmark_file_stream(self, evaluate_code: str) -> Tuple[str, io.BytesIO]:
+        """Get benchmark result file stream for download
+
+        Args:
+            evaluate_code (str): The evaluation code
+
+        Returns:
+            Tuple[str, io.BytesIO]: File name and file stream
+
+        Raises:
+            Exception: If the evaluation record is not found or the file does not exist
+        """
+        if not evaluate_code:
+            raise Exception("evaluate_code is required")
+
+        # 1. Query the evaluation record by evaluate_code
+        try:
+            entity = self.dao.get_one({"evaluate_code": evaluate_code})
+            if not entity:
+                raise Exception(f"Evaluation record not found for code: {evaluate_code}")
+        except Exception as e:
+            logger.error(f"Failed to query evaluation record: {e}")
+            raise Exception(f"Failed to query evaluation record: {str(e)}")
+
+        # 2. Resolve the file from the result file path
+        file_path = entity.result
+        if not file_path:
+            raise Exception(f"No result file path found for evaluate_code: {evaluate_code}")
+
+        # Check whether the file exists
+        if not os.path.exists(file_path):
+            raise Exception(f"Result file not found: {file_path}")
+
+        try:
+            # Read the file content into memory
+            with open(file_path, 'rb') as file:
+                file_content = file.read()
+
+            # Create an in-memory byte stream
+            file_stream = io.BytesIO(file_content)
+
+            # Extract the file name
+            file_name = os.path.basename(file_path)
+
+            logger.info(f"Successfully prepared file stream for download: {file_name}")
+            return file_name, file_stream
+
+        except Exception as e:
+            logger.error(f"Failed to read result file {file_path}: {e}")
+            raise Exception(f"Failed to read result file: {str(e)}")
\ No newline at end of file
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py
index e75de3f6a..e062653fc 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py
@@ -62,6 +62,8 @@ class AnswerExecuteModel:
     cotTokens: Optional[Any] = None
     cost_time: Optional[int] = None
     llm_code: Optional[str] = None
+    knowledge: Optional[str] = None
+    prompt: Optional[str] = None
 
     @staticmethod
     def from_dict(d: Dict[str, Any]) -> "AnswerExecuteModel":
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py
index 2a0a3c4eb..1567cc6ef 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py
@@ -274,6 +274,7 @@ class UserInputExecuteService:
             logger.error(
                 f"[benchmark_task] queryResult error! sql = {sql}, errorMsg: {e}"
             )
+            error_msg = str(e)
 
         logger.info(f"[benchmark_task] queryResult end! result = {execute_result}")
         return AnswerExecuteModel(
@@ -285,6 +286,8 @@ class UserInputExecuteService:
             cotTokens=response.cot_tokens,
             errorMsg=error_msg,
             llm_code=input.llm_code,
+            knowledge=input.knowledge,
+            prompt=input.prompt,
         )
 
     def _extract_sql_content(self, content: str) -> str:
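
Usage sketch (not part of the patch): a minimal client call against the new download endpoint. It assumes the evaluate router is mounted under /api/v2/serve/evaluate on the default DB-GPT port 5670 and that check_api_key accepts a Bearer token; these are assumptions, adjust to your deployment.

    # Hypothetical client-side example; BASE_URL, API_KEY and the evaluate_code value
    # are placeholders, and the route prefix is assumed, not confirmed by the patch.
    import requests

    BASE_URL = "http://localhost:5670/api/v2/serve/evaluate"  # assumed prefix
    API_KEY = "your-api-key"  # assumed Bearer-token auth for check_api_key

    resp = requests.get(
        f"{BASE_URL}/benchmark_result_download",
        params={"evaluate_code": "your-evaluate-code"},
        headers={"Authorization": f"Bearer {API_KEY}"},
        timeout=60,
    )
    resp.raise_for_status()

    # The endpoint streams an xlsx payload; write it to disk.
    with open("benchmark_result.xlsx", "wb") as f:
        f.write(resp.content)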