feat(benchmark): benchmark result file download
@@ -5,6 +5,7 @@ from functools import cache
 from typing import List, Optional
 
 from fastapi import APIRouter, Depends, Query, HTTPException, BackgroundTasks
+from fastapi.responses import StreamingResponse
 from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
 
 from dbgpt.agent.core.schema import Status
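The only change in this hunk is the StreamingResponse import, which backs the new download endpoint further down. StreamingResponse wraps a file-like object or byte iterator and sends it to the client incrementally. A minimal standalone sketch of the pattern (not code from this commit; the route and payload are invented for illustration):

```python
import io

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/download-sketch")
async def download_sketch():
    # Any file-like object (or iterator of bytes) works as the content.
    buffer = io.BytesIO(b"col_a,col_b\n1,2\n")
    return StreamingResponse(
        buffer,
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=example.csv"},
    )
```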
@@ -262,7 +263,7 @@ async def get_compare_run_detail(summary_id: int, limit: int = 200, offset: int
     return Result.succ(detail)
 
 
-@router.post("/execute_benchmark_task")
+@router.post("/execute_benchmark_task", dependencies=[Depends(check_api_key)])
 async def execute_benchmark_task(
     request: BenchmarkServeRequest,
     background_tasks: BackgroundTasks,
@@ -296,7 +297,7 @@ async def execute_benchmark_task(
     })
 
 
-@router.get("/benchmark_task_list")
+@router.get("/benchmark_task_list", dependencies=[Depends(check_api_key)])
 async def benchmark_task_list(
     request: EvaluateServeRequest,
     page: Optional[int] = Query(default=1, description="current page"),
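These two hunks retrofit the existing task endpoints with the same check_api_key dependency the new download route uses, so all three now reject unauthenticated requests. Given the HTTPBearer import in this module, clients most likely authenticate with a bearer token; a hedged client-side sketch (base URL, route prefix, and key are placeholders, and the exact credential format depends on how check_api_key validates it):

```python
import httpx

BASE_URL = "http://localhost:5670"  # placeholder deployment address
API_KEY = "your-api-key"            # placeholder credential

# Without the Authorization header this call would now be rejected.
resp = httpx.get(
    f"{BASE_URL}/benchmark_task_list",
    headers={"Authorization": f"Bearer {API_KEY}"},
    params={"page": 1, "page_size": 20},
)
resp.raise_for_status()
print(resp.json())
```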
@@ -341,6 +342,52 @@ async def get_benchmark_table_rows(table: str, limit: int = 10):
     return Result.succ({"table": table, "limit": limit, "rows": rows})
 
 
+@router.get("/benchmark_result_download", dependencies=[Depends(check_api_key)])
+async def download_benchmark_result(
+    evaluate_code: Optional[str] = Query(default=None, description="evaluate code"),
+    service: BenchmarkService = Depends(get_benchmark_service),
+):
+    """Download benchmark result file
+
+    Args:
+        evaluate_code: The evaluation code to identify the benchmark result
+        service: The benchmark service instance
+
+    Returns:
+        StreamingResponse: File download response
+
+    Raises:
+        HTTPException: If the evaluation code is missing or the file is not found
+    """
+    logger.info(f"download benchmark result: {evaluate_code}")
+
+    if not evaluate_code:
+        raise HTTPException(status_code=400, detail="evaluate_code is required")
+
+    try:
+        # Get the file name and file stream
+        file_name, file_stream = await service.get_benchmark_file_stream(evaluate_code)
+
+        from urllib.parse import quote
+
+        # Percent-encode the file name to support Chinese and special characters
+        encoded_filename = quote(file_name)
+
+        # Return the file download response
+        return StreamingResponse(
+            content=file_stream,
+            media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+            headers={
+                "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}; filename={encoded_filename}",
+                "Content-Type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
+            }
+        )
+
+    except Exception as e:
+        logger.error(f"Failed to download benchmark result for {evaluate_code}: {str(e)}")
+        raise HTTPException(status_code=404, detail=str(e))
+
+
 def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None:
     """Initialize the endpoints"""
     global global_system_app
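Two details of the handler are worth noting. The media type pins the response to .xlsx, and the Content-Disposition header carries the file name twice: `filename*=UTF-8''…` is the RFC 6266/5987 form that modern clients decode back to the original (possibly non-ASCII) name, while the plain `filename=` is a fallback for older clients. urllib.parse.quote does the percent-encoding; for example (the file name here is invented for illustration):

```python
from urllib.parse import quote

file_name = "基准测试结果.xlsx"  # a hypothetical result file with a Chinese name
encoded = quote(file_name)
print(encoded)
# -> %E5%9F%BA%E5%87%86%E6%B5%8B%E8%AF%95%E7%BB%93%E6%9E%9C.xlsx

header = f"attachment; filename*=UTF-8''{encoded}; filename={encoded}"
# Clients that understand RFC 6266 restore the original name from
# filename*; others save the percent-encoded fallback.
```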
@@ -70,7 +70,9 @@ class BenchmarkLLMTask:
             return None
 
         if response.has_text:
-            return ReasoningResponse(content=self._get_answer(response.text))
+            return ReasoningResponse(cot_tokens=response.usage.get("total_tokens", 0),
+                                     think=response.thinking_text if response.has_thinking else None,
+                                     content=self._get_answer(response.text))
         else:
            return None
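Beyond the answer text, the response now carries reasoning metadata: total token usage (defaulting to 0 when the backend reports none) and the model's thinking text when present. A stand-in sketch of the shape being built — ReasoningResponse here is a minimal stand-in for illustration, not the repo's actual class:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ReasoningResponse:  # stand-in for illustration only
    content: str
    cot_tokens: int = 0
    think: Optional[str] = None


# dict.get with a default avoids a KeyError when usage lacks total_tokens.
usage = {"prompt_tokens": 120, "completion_tokens": 80}
resp = ReasoningResponse(content="SELECT 1", cot_tokens=usage.get("total_tokens", 0))
assert resp.cot_tokens == 0 and resp.think is None
```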
@@ -7,7 +7,8 @@ import uuid
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, List, Optional, Union, Any
+from typing import Dict, List, Optional, Union, Any, Tuple
+import io
 
 from dbgpt.agent.core.schema import Status
 from dbgpt.component import ComponentType, SystemApp
@@ -481,7 +482,7 @@ class BenchmarkService(
             except Exception as e:
                 logger.error(f"Batch write error: {e}")
 
-        self.trigger_executor.submit(batch_write_task)
+        batch_write_task()
         written_batches.add(batch_index)
 
     def post_dispatch(self, i: int, config: BenchmarkExecuteConfig,
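Switching from trigger_executor.submit(batch_write_task) to a direct call makes the batch write synchronous: the write now completes before written_batches.add(batch_index) records it (the task logs its own exceptions, so ordering is the key change here). A sketch of the general difference between the two call styles:

```python
from concurrent.futures import ThreadPoolExecutor


def batch_write_task():
    raise RuntimeError("disk full")


executor = ThreadPoolExecutor(max_workers=1)

# Fire-and-forget: the exception is stored on the Future and goes
# unnoticed unless someone calls future.result() or future.exception().
future = executor.submit(batch_write_task)

# Direct call: the failure propagates immediately, before the batch
# would be marked as written.
try:
    batch_write_task()
except RuntimeError as e:
    print(f"write failed: {e}")
```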
@@ -519,3 +520,53 @@ class BenchmarkService(
             query_request = request
         return self.dao.get_list_page(query_request, page, page_size, ServeEntity.id.name)
 
+    async def get_benchmark_file_stream(self, evaluate_code: str) -> Tuple[str, io.BytesIO]:
+        """Get benchmark result file stream for download
+
+        Args:
+            evaluate_code (str): The evaluation code
+
+        Returns:
+            Tuple[str, io.BytesIO]: File name and file stream
+
+        Raises:
+            Exception: If the evaluation record is not found or the file does not exist
+        """
+        if not evaluate_code:
+            raise Exception("evaluate_code is required")
+
+        # 1. Look up the evaluation record by evaluate_code
+        try:
+            entity = self.dao.get_one({"evaluate_code": evaluate_code})
+            if not entity:
+                raise Exception(f"Evaluation record not found for code: {evaluate_code}")
+        except Exception as e:
+            logger.error(f"Failed to query evaluation record: {e}")
+            raise Exception(f"Failed to query evaluation record: {str(e)}")
+
+        # 2. Fetch the file from the path stored in the result field
+        file_path = entity.result
+        if not file_path:
+            raise Exception(f"No result file path found for evaluate_code: {evaluate_code}")
+
+        # Check that the file exists
+        if not os.path.exists(file_path):
+            raise Exception(f"Result file not found: {file_path}")
+
+        try:
+            # Read the file content into memory
+            with open(file_path, 'rb') as file:
+                file_content = file.read()
+
+            # Wrap the content in a byte stream
+            file_stream = io.BytesIO(file_content)
+
+            # Get the file name
+            file_name = os.path.basename(file_path)
+
+            logger.info(f"Successfully prepared file stream for download: {file_name}")
+            return file_name, file_stream
+
+        except Exception as e:
+            logger.error(f"Failed to read result file {file_path}: {e}")
+            raise Exception(f"Failed to read result file: {str(e)}")
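The service reads the entire result file into a BytesIO before streaming it, which is simple and adequate for typical spreadsheet sizes but holds the whole file in memory per download. If result files grow large, a chunked generator keeps memory flat; a sketch of the alternative, not part of this commit:

```python
from typing import Iterator


def iter_file_chunks(file_path: str, chunk_size: int = 64 * 1024) -> Iterator[bytes]:
    """Yield the file in fixed-size chunks instead of loading it all at once."""
    with open(file_path, "rb") as f:
        while chunk := f.read(chunk_size):
            yield chunk


# StreamingResponse accepts any iterator of bytes, so the endpoint could
# pass iter_file_chunks(file_path) instead of a fully materialized BytesIO.
```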
@@ -62,6 +62,8 @@ class AnswerExecuteModel:
     cotTokens: Optional[Any] = None
     cost_time: Optional[int] = None
     llm_code: Optional[str] = None
+    knowledge: Optional[str] = None
+    prompt: Optional[str] = None
 
     @staticmethod
     def from_dict(d: Dict[str, Any]) -> "AnswerExecuteModel":
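The two new optional fields let a result record carry the knowledge context and prompt that produced it, without breaking callers that build the model from older payloads: absent keys simply leave the None defaults. A minimal stand-in showing that tolerance (the repo's real from_dict is not shown in this diff, so its filtering style is an assumption):

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass
class AnswerSketch:  # stand-in, not the repo's AnswerExecuteModel
    llm_code: Optional[str] = None
    knowledge: Optional[str] = None
    prompt: Optional[str] = None

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "AnswerSketch":
        # Keep only known keys so pre-change payloads (without the new
        # fields) and post-change payloads both deserialize cleanly.
        known = {f.name for f in fields(AnswerSketch)}
        return AnswerSketch(**{k: v for k, v in d.items() if k in known})


assert AnswerSketch.from_dict({"llm_code": "gpt-4"}).knowledge is None
```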
@@ -274,6 +274,7 @@ class UserInputExecuteService:
             logger.error(
                 f"[benchmark_task] queryResult error! sql = {sql}, errorMsg: {e}"
             )
+            error_msg = str(e)
         logger.info(f"[benchmark_task] queryResult end! result = {execute_result}")
 
         return AnswerExecuteModel(
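Capturing error_msg = str(e) inside the except block is necessary, not just convenient: Python unbinds the exception variable when the except clause ends, so e is no longer readable later when AnswerExecuteModel is built. A quick demonstration:

```python
error_msg = None
try:
    1 / 0
except ZeroDivisionError as e:
    error_msg = str(e)  # capture here, while `e` is still bound

# After the except block, `e` no longer exists:
# print(e)  -> NameError: name 'e' is not defined
print(error_msg)  # "division by zero"
```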
@@ -285,6 +286,8 @@ class UserInputExecuteService:
             cotTokens=response.cot_tokens,
             errorMsg=error_msg,
             llm_code=input.llm_code,
+            knowledge=input.knowledge,
+            prompt=input.prompt,
         )
 
     def _extract_sql_content(self, content: str) -> str: