Merge remote-tracking branch 'origin/feat_dataset_benchmark' into feat_dataset_benchmark

This commit is contained in:
yaoyifan-yyf
2025-10-13 16:59:18 +08:00
4 changed files with 96 additions and 16 deletions

View File

@@ -1,11 +1,13 @@
 import json
 import logging
+import asyncio
 from functools import cache
 from typing import List, Optional
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends, Query, HTTPException, BackgroundTasks
 from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
+from dbgpt.agent.core.schema import Status
 from dbgpt.component import ComponentType, SystemApp
 from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
 from dbgpt_serve.core import Result
@@ -33,6 +35,37 @@ router = APIRouter()
 global_system_app: Optional[SystemApp] = None
 logger = logging.getLogger(__name__)
+
+def _run_benchmark_task_sync(
+    service: BenchmarkService,
+    evaluate_code: str,
+    scene_key: str,
+    scene_value: str,
+    input_file_path: str,
+    output_file_path: str,
+    model_list: List[str],
+):
+    """Synchronous helper that executes the benchmark task, for running it as a background task."""
+    try:
+        # Create a new event loop to run the async task
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            loop.run_until_complete(
+                service.run_dataset_benchmark(
+                    evaluate_code,
+                    scene_key,
+                    scene_value,
+                    input_file_path,
+                    output_file_path,
+                    model_list,
+                )
+            )
+            logger.info(f"Benchmark task run sync finish, evaluate_code: {evaluate_code}")
+        finally:
+            loop.close()
+    except Exception as e:
+        logger.error(f"Benchmark task failed for evaluate_code: {evaluate_code}, error: {str(e)}")
 
 
 def get_service() -> Service:
     """Get the service instance"""
@@ -232,24 +265,52 @@ async def get_compare_run_detail(summary_id: int, limit: int = 200, offset: int
 @router.post("/execute_benchmark_task")
 async def execute_benchmark_task(
     request: BenchmarkServeRequest,
+    background_tasks: BackgroundTasks,
     service: BenchmarkService = Depends(get_benchmark_service),
 ) -> Result:
     """execute benchmark task
 
     Args:
-        request (EvaluateServeRequest): The request
-        service (Service): The service
+        request (BenchmarkServeRequest): The request
+        background_tasks (BackgroundTasks): FastAPI background tasks
+        service (BenchmarkService): The service
 
     Returns:
-        ServerResponse: The response
+        Result: The response
     """
+    # Use FastAPI's BackgroundTasks to execute the task in the background
+    background_tasks.add_task(
+        _run_benchmark_task_sync,
+        service,
+        request.evaluate_code,
+        request.scene_key,
+        request.scene_value,
+        request.input_file_path,
+        request.output_file_path,
+        request.model_list,
+    )
+    # Return a success response immediately
+    return Result.succ({
+        "evaluate_code": request.evaluate_code,
+        "status": Status.RUNNING.value
+    })
+
+
+@router.get("/benchmark_task_list")
+async def benchmark_task_list(
+    request: EvaluateServeRequest,
+    page: Optional[int] = Query(default=1, description="current page"),
+    page_size: Optional[int] = Query(default=20, description="page size"),
+    service: BenchmarkService = Depends(get_benchmark_service),
+) -> Result:
+    """
+    Query benchmark task list
+    """
     return Result.succ(
-        await service.run_dataset_benchmark(
-            request.evaluate_code,
-            request.scene_key,
-            request.scene_value,
-            request.input_file_path,
-            request.output_file_path,
-            request.model_list,
+        service.get_list_by_page(
+            request,
+            page,
+            page_size,
         )
     )
@@ -285,4 +346,4 @@ def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None:
     global global_system_app
     system_app.register(Service, config=config)
     system_app.register(BenchmarkService, config=config)
-    global_system_app = system_app
\ No newline at end of file
+    global_system_app = system_app
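With this change, `POST /execute_benchmark_task` schedules the work and answers immediately with a RUNNING status instead of blocking until the benchmark completes. A hypothetical client call; the host, port, and route prefix are assumptions (the router mount point is not part of this diff), and the payload values are illustrative only:

```python
import requests

payload = {
    "evaluate_code": "bench-demo-001",       # illustrative value
    "scene_key": "app_code",                 # illustrative value
    "scene_value": "chat_demo",              # illustrative value
    "input_file_path": "/tmp/benchmark_in.xlsx",
    "output_file_path": "/tmp/benchmark_out.xlsx",
    "model_list": ["model-a", "model-b"],    # illustrative model names
}

# The endpoint returns before the benchmark finishes; query
# /benchmark_task_list afterwards to observe progress.
resp = requests.post(
    "http://localhost:5670/execute_benchmark_task",  # prefix assumed
    json=payload,
)
print(resp.json())  # e.g. data: {"evaluate_code": "...", "status": "RUNNING"}
```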

View File

@@ -84,6 +84,8 @@ class BenchmarkServeRequest(BaseModel):
state: Optional[str] = Field(None, description="evaluation state")
temperature: Optional[str] = Field(None, description="evaluation state")
max_tokens: Optional[str] = Field(None, description="evaluation state")
gmt_create: Optional[str] = Field(None, description="create time")
gmt_modified: Optional[str] = Field(None, description="create time")
class StorageType(Enum):

View File

@@ -16,8 +16,7 @@ from dbgpt.core import LLMClient
 from dbgpt.model import DefaultLLMClient
 from dbgpt.model.cluster import WorkerManagerFactory
 from dbgpt.storage.metadata import BaseDao
-from dbgpt.util.benchmarks import StorageUtil
-from dbgpt.util import get_or_create_event_loop
+from dbgpt.util import get_or_create_event_loop, PaginationResult
 from ....core import BaseService
 from ....prompt.service.service import Service as PromptService
@@ -281,6 +280,8 @@ class BenchmarkService(
             )
             result_list.extend(round_result_list)
 
+            logger.info(f"Benchmark task completed successfully for evaluate_code:"
+                        f" {evaluate_code}, output_file_path: {output_file_path}")
             return result_list
 
     async def _build_benchmark_config(self, model_list, output_file_path,
@@ -502,3 +503,19 @@ class BenchmarkService(
             output_file_path,
         )
+
+    def get_list_by_page(
+        self, request: EvaluateServeRequest, page: int, page_size: int
+    ) -> PaginationResult[EvaluateServeResponse]:
+        """Get a list of Evaluate entities by page
+
+        Args:
+            request (EvaluateServeRequest): The request
+            page (int): The page number
+            page_size (int): The page size
+
+        Returns:
+            PaginationResult[EvaluateServeResponse]: The paginated response
+        """
+        query_request = request
+        return self.dao.get_list_page(query_request, page, page_size, ServeEntity.id.name)
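`get_list_by_page` delegates to the DAO, whose implementation is not part of this diff. For reference, a page/page_size pair reduces to an offset/limit slice; a self-contained sketch of that arithmetic, with a simplified stand-in for `PaginationResult`:

```python
from dataclasses import dataclass
from typing import Generic, List, TypeVar

T = TypeVar("T")


@dataclass
class Page(Generic[T]):
    """Simplified stand-in for a pagination result."""
    items: List[T]
    total_count: int
    page: int
    page_size: int


def paginate(rows: List[T], page: int, page_size: int) -> Page[T]:
    # page is 1-based; convert it to a 0-based slice offset
    offset = (max(page, 1) - 1) * page_size
    return Page(
        items=rows[offset:offset + page_size],
        total_count=len(rows),
        page=page,
        page_size=page_size,
    )


# e.g. paginate(list(range(95)), page=2, page_size=20).items -> rows 20..39
```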

View File

@@ -93,7 +93,7 @@ class FileParseService(ABC):
         ]
 
         # Load or create workbook and sheet
-        if output_file.exists():
+        if Path(output_file).exists():
             workbook = load_workbook(str(output_file))
             if "benchmark_compare_result" in workbook.sheetnames:
                 worksheet = workbook["benchmark_compare_result"]
@@ -425,7 +425,7 @@ class ExcelFileParseService(FileParseService):
         if extension.lower() not in [".xlsx", ".xls"]:
             extension = ".xlsx"
-        output_file = output_dir / f"{base_name}_round{round_id}{extension}"
+        output_file = output_dir / f"{base_name}{extension}"
 
         # Build a map of the input data keyed by serial number for easy lookup
         input_map = {inp.serial_no: inp for inp in inputs}
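Two behavioral notes on this last file: `Path(output_file).exists()` guards against `output_file` arriving as a plain `str` (which has no `.exists()` method), and dropping `_round{round_id}` from the filename makes every round write to one workbook, consistent with the load-or-append logic in the earlier hunk. A minimal sketch of the guard, assuming the value may be either `str` or `Path`:

```python
from pathlib import Path
from typing import Union


def ensure_path(output_file: Union[str, Path]) -> Path:
    """Coerce a str-or-Path value to Path so .exists() is always available."""
    return output_file if isinstance(output_file, Path) else Path(output_file)


# A str and an equivalent Path now behave identically:
assert ensure_path("out.xlsx").exists() == Path("out.xlsx").exists()
```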