Merge remote-tracking branch 'origin/feat_dataset_benchmark' into feat_dataset_benchmark
@@ -1,11 +1,13 @@
 import json
 import logging
+import asyncio
 from functools import cache
 from typing import List, Optional
 
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends, Query, HTTPException, BackgroundTasks
 from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
 
+from dbgpt.agent.core.schema import Status
 from dbgpt.component import ComponentType, SystemApp
 from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
 from dbgpt_serve.core import Result
@@ -33,6 +35,37 @@ router = APIRouter()
 global_system_app: Optional[SystemApp] = None
 logger = logging.getLogger(__name__)
+
+
+def _run_benchmark_task_sync(
+    service: BenchmarkService,
+    evaluate_code: str,
+    scene_key: str,
+    scene_value: str,
+    input_file_path: str,
+    output_file_path: str,
+    model_list: List[str],
+):
+    """Helper that runs the benchmark task synchronously, for use in a background task."""
+    try:
+        # Create a new event loop to run the async task
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            loop.run_until_complete(
+                service.run_dataset_benchmark(
+                    evaluate_code,
+                    scene_key,
+                    scene_value,
+                    input_file_path,
+                    output_file_path,
+                    model_list,
+                )
+            )
+            logger.info(f"Benchmark task run sync finish, evaluate_code: {evaluate_code}")
+        finally:
+            loop.close()
+    except Exception as e:
+        logger.error(f"Benchmark task failed for evaluate_code: {evaluate_code}, error: {str(e)}")
+
+
 def get_service() -> Service:
     """Get the service instance"""
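The helper above builds and tears down an event loop by hand. Since the worker thread that runs a sync background task has no running loop, `asyncio.run` can handle the same loop lifecycle in one call. A minimal alternative sketch, not the committed code:

```python
# Alternative sketch only (not the committed code): asyncio.run creates a
# fresh event loop, runs the coroutine, and always closes the loop afterwards,
# replacing the manual new_event_loop/set_event_loop/close dance above.
import asyncio
import logging
from typing import List

logger = logging.getLogger(__name__)


def _run_benchmark_task_sync(
    service,  # BenchmarkService, as in the diff above
    evaluate_code: str,
    scene_key: str,
    scene_value: str,
    input_file_path: str,
    output_file_path: str,
    model_list: List[str],
):
    try:
        asyncio.run(
            service.run_dataset_benchmark(
                evaluate_code,
                scene_key,
                scene_value,
                input_file_path,
                output_file_path,
                model_list,
            )
        )
        logger.info(f"Benchmark task run sync finish, evaluate_code: {evaluate_code}")
    except Exception as e:
        logger.error(f"Benchmark task failed for evaluate_code: {evaluate_code}, error: {e}")
```

Note that `asyncio.run` refuses to start when the calling thread already has a running loop, which is exactly the guarantee this helper relies on.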
@@ -232,24 +265,52 @@ async def get_compare_run_detail(summary_id: int, limit: int = 200, offset: int
 @router.post("/execute_benchmark_task")
 async def execute_benchmark_task(
     request: BenchmarkServeRequest,
+    background_tasks: BackgroundTasks,
     service: BenchmarkService = Depends(get_benchmark_service),
 ) -> Result:
     """Execute benchmark task
 
     Args:
-        request (EvaluateServeRequest): The request
-        service (Service): The service
+        request (BenchmarkServeRequest): The request
+        background_tasks (BackgroundTasks): FastAPI background tasks
+        service (BenchmarkService): The service
     Returns:
-        ServerResponse: The response
+        Result: The response
     """
-    return Result.succ(
-        await service.run_dataset_benchmark(
-            request.evaluate_code,
-            request.scene_key,
-            request.scene_value,
-            request.input_file_path,
-            request.output_file_path,
-            request.model_list,
-        )
-    )
+    # Use FastAPI's BackgroundTasks to execute the background task
+    background_tasks.add_task(
+        _run_benchmark_task_sync,
+        service,
+        request.evaluate_code,
+        request.scene_key,
+        request.scene_value,
+        request.input_file_path,
+        request.output_file_path,
+        request.model_list,
+    )
+
+    # Return a success response immediately
+    return Result.succ({
+        "evaluate_code": request.evaluate_code,
+        "status": Status.RUNNING.value
+    })
+
+
+@router.get("/benchmark_task_list")
+async def benchmark_task_list(
+    request: EvaluateServeRequest,
+    page: Optional[int] = Query(default=1, description="current page"),
+    page_size: Optional[int] = Query(default=20, description="page size"),
+    service: BenchmarkService = Depends(get_benchmark_service),
+) -> Result:
+    """
+    Query benchmark task list
+    """
+    return Result.succ(
+        service.get_list_by_page(
+            request,
+            page,
+            page_size,
+        )
+    )
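The endpoint now returns immediately while the benchmark runs in the background. A sketch of what a client call might look like against a running server; the host, port, URL prefix, and all field values are assumptions, not taken from this diff:

```python
# Sketch: calling the new fire-and-forget endpoint. The URL (host, port, and
# route prefix) and the payload values below are hypothetical examples.
import httpx

resp = httpx.post(
    "http://localhost:5670/api/v2/serve/benchmark/execute_benchmark_task",  # assumed URL
    json={
        "evaluate_code": "demo-001",           # hypothetical task id
        "scene_key": "chat_data",              # hypothetical scene
        "scene_value": "default",
        "input_file_path": "/tmp/in.xlsx",
        "output_file_path": "/tmp/out.xlsx",
        "model_list": ["model-a", "model-b"],
    },
)
# Per the handler, the body carries evaluate_code plus Status.RUNNING.value.
print(resp.json())
```

Because the response is sent before the task finishes, clients must poll for completion rather than treat the 200 as a result.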
@@ -285,4 +346,4 @@ def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None:
     global global_system_app
     system_app.register(Service, config=config)
+    system_app.register(BenchmarkService, config=config)
     global_system_app = system_app
@@ -84,6 +84,8 @@ class BenchmarkServeRequest(BaseModel):
     state: Optional[str] = Field(None, description="evaluation state")
+    temperature: Optional[str] = Field(None, description="sampling temperature")
+    max_tokens: Optional[str] = Field(None, description="max tokens to generate")
     gmt_create: Optional[str] = Field(None, description="create time")
     gmt_modified: Optional[str] = Field(None, description="modified time")
 
 
 class StorageType(Enum):
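The two new fields are declared as `Optional[str]` even though they carry numeric generation parameters. If callers can send numbers, a stricter declaration would let Pydantic validate ranges. A hypothetical sketch of that tightening, not the shipped model:

```python
# Sketch: numeric typing for the new generation-parameter fields, with range
# validation. BenchmarkGenerationParams is a hypothetical illustration, not
# part of the DB-GPT codebase.
from typing import Optional

from pydantic import BaseModel, Field


class BenchmarkGenerationParams(BaseModel):
    temperature: Optional[float] = Field(
        None, ge=0.0, le=2.0, description="sampling temperature"
    )
    max_tokens: Optional[int] = Field(
        None, gt=0, description="maximum tokens to generate"
    )
```

Keeping them as strings works too, but then every consumer must parse and validate the values itself.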
@@ -16,8 +16,7 @@ from dbgpt.core import LLMClient
 from dbgpt.model import DefaultLLMClient
 from dbgpt.model.cluster import WorkerManagerFactory
 from dbgpt.storage.metadata import BaseDao
 from dbgpt.util.benchmarks import StorageUtil
-from dbgpt.util import get_or_create_event_loop
+from dbgpt.util import get_or_create_event_loop, PaginationResult
 
 from ....core import BaseService
 from ....prompt.service.service import Service as PromptService
@@ -281,6 +280,8 @@ class BenchmarkService(
             )
             result_list.extend(round_result_list)
 
+        logger.info(f"Benchmark task completed successfully for evaluate_code:"
+                    f" {evaluate_code}, output_file_path: {output_file_path}")
         return result_list
 
     async def _build_benchmark_config(self, model_list, output_file_path,
@@ -502,3 +503,19 @@ class BenchmarkService(
             output_file_path,
         )
 
+    def get_list_by_page(
+        self, request: EvaluateServeRequest, page: int, page_size: int
+    ) -> PaginationResult[EvaluateServeResponse]:
+        """Get a list of Evaluate entities by page
+
+        Args:
+            request (EvaluateServeRequest): The request
+            page (int): The page number
+            page_size (int): The page size
+
+        Returns:
+            PaginationResult[EvaluateServeResponse]: The paginated response
+        """
+        return self.dao.get_list_page(request, page, page_size, ServeEntity.id.name)
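The new method delegates pagination to `dao.get_list_page`, ordered by `ServeEntity.id`. For reference, the page-to-offset arithmetic such a query performs under the hood; this is illustrative only, since the actual DAO implementation is not part of this diff:

```python
# Sketch: the offset arithmetic behind a 1-indexed page query. Illustrative
# only; the real dao.get_list_page implementation may differ.
def page_slice(page: int, page_size: int) -> tuple[int, int]:
    """Return (offset, limit) for a 1-indexed page."""
    offset = (page - 1) * page_size
    return offset, page_size


assert page_slice(1, 20) == (0, 20)   # first page starts at row 0
assert page_slice(3, 20) == (40, 20)  # third page skips two full pages
```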
@@ -93,7 +93,7 @@ class FileParseService(ABC):
         ]
 
         # Load or create workbook and sheet
-        if output_file.exists():
+        if Path(output_file).exists():
             workbook = load_workbook(str(output_file))
             if "benchmark_compare_result" in workbook.sheetnames:
                 worksheet = workbook["benchmark_compare_result"]
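The one-line change matters when `output_file` arrives as a plain `str`: calling `.exists()` on a string raises `AttributeError`, while `Path(output_file)` accepts both `str` and `Path` (constructing a `Path` from a `Path` is a no-op). A minimal demonstration:

```python
# Sketch: Path(output_file).exists() is safe whether output_file is a str or
# a Path, whereas output_file.exists() only works on Path objects.
from pathlib import Path

for output_file in ["/tmp/report.xlsx", Path("/tmp/report.xlsx")]:
    exists = Path(output_file).exists()  # works for both types
    print(type(output_file).__name__, exists)
```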
@@ -425,7 +425,7 @@ class ExcelFileParseService(FileParseService):
         if extension.lower() not in [".xlsx", ".xls"]:
             extension = ".xlsx"
 
-        output_file = output_dir / f"{base_name}_round{round_id}{extension}"
+        output_file = output_dir / f"{base_name}{extension}"
 
         # Create a mapping of input data for easy lookup
         input_map = {inp.serial_no: inp for inp in inputs}
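Dropping the `_round{round_id}` suffix means every round now writes into the same workbook, which is why the load-or-create logic in the earlier hunk has to reopen an existing file instead of always starting fresh. A standalone sketch of that pattern with openpyxl; the file path is hypothetical, the sheet name comes from the diff:

```python
# Sketch: load-or-create an openpyxl workbook and a named result sheet, so
# repeated rounds append to one file instead of producing one file per round.
from pathlib import Path

from openpyxl import Workbook, load_workbook

output_file = Path("/tmp/benchmark_out.xlsx")  # hypothetical path

if Path(output_file).exists():
    workbook = load_workbook(str(output_file))
    if "benchmark_compare_result" in workbook.sheetnames:
        worksheet = workbook["benchmark_compare_result"]
    else:
        worksheet = workbook.create_sheet("benchmark_compare_result")
else:
    workbook = Workbook()
    worksheet = workbook.active
    worksheet.title = "benchmark_compare_result"

workbook.save(str(output_file))
```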