fix(benchmark): execute benchmark with model param

@@ -38,12 +38,7 @@ logger = logging.getLogger(__name__)
 def _run_benchmark_task_sync(
     service: BenchmarkService,
-    evaluate_code: str,
-    scene_key: str,
-    scene_value: str,
-    input_file_path: str,
-    output_file_path: str,
-    model_list: List[str],
+    request: BenchmarkServeRequest
 ):
     """Helper that runs a benchmark task synchronously; used in background tasks."""
     try:

@@ -53,22 +48,24 @@ def _run_benchmark_task_sync(
         try:
             loop.run_until_complete(
                 service.run_dataset_benchmark(
-                    evaluate_code,
-                    scene_key,
-                    scene_value,
-                    input_file_path,
-                    output_file_path,
-                    model_list,
+                    request.evaluate_code,
+                    request.scene_key,
+                    request.scene_value,
+                    request.input_file_path,
+                    request.output_file_path,
+                    request.model_list,
+                    request.temperature,
+                    request.max_tokens,
                 )
             )
             logger.info(
-                f"Benchmark task run sync finish, evaluate_code: {evaluate_code}"
+                f"Benchmark task run sync finish, request: {request}"
             )
         finally:
             loop.close()
     except Exception as e:
         logger.error(
-            f"Benchmark task failed for evaluate_code: {evaluate_code}, error: {str(e)}"
-        )
+            f"Benchmark task failed for request: {request}, error: {str(e)}"
+        )
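
The helper drives the async service API from a synchronous function by owning a private event loop: the hunk shows the run_until_complete call and the finally: loop.close(), while the loop creation sits in the lines elided above it. A minimal sketch of that pattern, assuming (not shown in the diff) that the loop is created with asyncio.new_event_loop():

import asyncio

def run_coro_sync(coro):
    # Create a private event loop: background tasks may execute on a
    # worker thread that has no running loop of its own.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # Block the current thread until the coroutine completes,
        # mirroring loop.run_until_complete(...) in the hunk above.
        return loop.run_until_complete(coro)
    finally:
        # Release loop resources, as the diff's finally block does.
        loop.close()

A caller would pass the coroutine directly, e.g. run_coro_sync(service.run_dataset_benchmark(...)).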

@@ -270,12 +267,7 @@ async def execute_benchmark_task(
     background_tasks.add_task(
         _run_benchmark_task_sync,
         service,
-        request.evaluate_code,
-        request.scene_key,
-        request.scene_value,
-        request.input_file_path,
-        request.output_file_path,
-        request.model_list,
+        request
     )

     # Return a success response immediately
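
FastAPI executes non-async callables registered via BackgroundTasks in a threadpool after the response has been sent, which is why the endpoint can hand the entire request object to the sync helper and return at once. A minimal, self-contained sketch of that shape (hypothetical names; the real route, auth, and service wiring live in the surrounding file):

from fastapi import BackgroundTasks, FastAPI

app = FastAPI()

def run_worker(payload: dict) -> None:
    # Long-running benchmark work; runs in a threadpool after the
    # HTTP response below has already been returned to the client.
    ...

@app.post("/benchmark/execute")
async def execute(payload: dict, background_tasks: BackgroundTasks):
    # Schedule the synchronous worker with the whole payload object,
    # then return immediately instead of awaiting the benchmark.
    background_tasks.add_task(run_worker, payload)
    return {"success": True}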

@@ -286,6 +278,7 @@ async def execute_benchmark_task(

 @router.get("/benchmark_task_list", dependencies=[Depends(check_api_key)])
 async def benchmark_task_list(
     state: Optional[str] = Query(default=None, description="benchmark task state"),
     page: Optional[int] = Query(default=1, description="current page"),
+    page_size: Optional[int] = Query(default=20, description="page size"),
     service: BenchmarkService = Depends(get_benchmark_service),

@@ -293,9 +286,12 @@ async def benchmark_task_list(
     """
     Query benchmark task list
     """
     request = EvaluateServeRequest(
         state=state,
     )
     return Result.succ(
         service.get_list_by_page(
             {},
             request,
+            page,
+            page_size,
         )
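
Assuming a local deployment and the evaluate serve module's mount prefix (both hypothetical here; only the /benchmark_task_list path and the state/page/page_size parameters come from the diff), the paginated listing could be exercised like this, with an illustrative state filter:

import requests

# Base URL, route prefix, API key, and state value are all hypothetical;
# adjust them to the actual deployment.
resp = requests.get(
    "http://localhost:5670/api/v2/serve/evaluate/benchmark_task_list",
    headers={"Authorization": "Bearer YOUR_API_KEY"},
    params={"state": "RUNNING", "page": 1, "page_size": 20},
)
print(resp.json())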

@@ -82,13 +82,18 @@ class BenchmarkServeRequest(BaseModel):
     user_name: Optional[str] = Field(None, description="user name")
     user_id: Optional[str] = Field(None, description="user id")
     sys_code: Optional[str] = Field(None, description="system code")
-    parallel_num: Optional[int] = Field(None, description="system code")
+    parallel_num: Optional[int] = Field(None, description="task parallel num")
     state: Optional[str] = Field(None, description="evaluation state")
-    temperature: Optional[str] = Field(None, description="evaluation state")
-    max_tokens: Optional[str] = Field(None, description="evaluation state")
-    log_info: Optional[str] = Field(None, description="evaluation log_info")
+    temperature: Optional[float] = Field(
+        0.7,
+        description="What sampling temperature to use, between 0 and 2. Higher values "
+        "like 0.8 will make the output more random, "
+        "while lower values like 0.2 will "
+        "make it more focused and deterministic.",)
+    max_tokens: Optional[int] = Field(None, description="Max tokens")
+    log_info: Optional[str] = Field(None, description="evaluation task error message")
     gmt_create: Optional[str] = Field(None, description="create time")
-    gmt_modified: Optional[str] = Field(None, description="create time")
+    gmt_modified: Optional[str] = Field(None, description="modified time")


 class BenchmarkServeResponse(BenchmarkServeRequest):
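
Changing temperature from Optional[str] to Optional[float] (default 0.7) and max_tokens from Optional[str] to Optional[int] means Pydantic now validates and coerces these values instead of threading raw strings through to the model call. A small standalone sketch of the effect (only these two fields, not the full model from the diff):

from typing import Optional

from pydantic import BaseModel, Field

class GenerationParams(BaseModel):
    # Same types and default as the corrected fields above.
    temperature: Optional[float] = Field(0.7, description="sampling temperature")
    max_tokens: Optional[int] = Field(None, description="max tokens to generate")

defaults = GenerationParams()                   # temperature falls back to 0.7
coerced = GenerationParams(temperature="0.2", max_tokens="1024")
print(defaults.temperature)                     # 0.7
print(coerced.temperature, coerced.max_tokens)  # 0.2 1024 (coerced from strings)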

@@ -58,7 +58,7 @@ BENCHMARK_SERVICE_COMPONENT_NAME = "dbgpt_serve_evaluate_benchmark_service"
 # TODO: needs to be replaced with the official file
 STANDARD_BENCHMARK_FILE_PATH = os.path.join(
     BENCHMARK_DATA_ROOT_PATH,
-    "2025_07_27_public_500_standard_benchmark_question_list_multi_anwser.xlsx",
+    "2025_07_27_public_500_standard_benchmark_question_list_v2.xlsx",
 )

 BENCHMARK_OUTPUT_RESULT_PATH = os.path.join(BENCHMARK_DATA_ROOT_PATH, "result")

@@ -220,6 +220,8 @@ class BenchmarkService(
         input_file_path: str,
         output_file_path: str,
         model_list: List[str],
+        temperature: Optional[float],
+        max_tokens: Optional[int],
     ) -> List[BenchmarkTaskResult[OutputType]]:
         """
         Run the dataset benchmark
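
Since the two new parameters are positional and have no defaults, every call site of run_dataset_benchmark must now pass them explicitly, as the background helper above does via request.temperature and request.max_tokens. A sketch of a call site, with hypothetical argument values:

async def kick_off(service: "BenchmarkService") -> None:
    # All argument values below are illustrative, not from the diff.
    await service.run_dataset_benchmark(
        "eval-2025-001",               # evaluate_code
        "scene_key",                   # scene_key
        "scene_value",                 # scene_value
        "/data/benchmark/input.xlsx",  # input_file_path
        "/data/benchmark/result",      # output_file_path
        ["model-a", "model-b"],        # model_list
        0.7,                           # temperature (new)
        2048,                          # max_tokens (new)
    )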