From 8d8d45531007a18b98ae2bb020e5dc1dcbefbdc1 Mon Sep 17 00:00:00 2001
From: yaoyifan-yyf
Date: Mon, 13 Oct 2025 15:55:42 +0800
Subject: [PATCH] opt: write multi-model compare results

---
 .../service/benchmark/file_parse_service.py   |  8 ++---
 .../evaluate/service/benchmark/models.py      |  1 +
 .../benchmark/user_input_execute_service.py   | 29 +++++++++++--------
 3 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py
index 4d0635af8..88b004ae1 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py
@@ -75,11 +75,7 @@ class FileParseService(ABC):
         output_dir.mkdir(parents=True, exist_ok=True)
 
         # Determine final excel file path: _round{round_id}.xlsx
-        base_name = Path(path).stem
-        extension = Path(path).suffix
-        if extension.lower() not in [".xlsx", ".xls"]:
-            extension = ".xlsx"
-        output_file = output_dir / f"{base_name}_round{round_id}{extension}"
+        output_file = path
 
         headers = [
             "serialNo",
@@ -88,6 +84,7 @@ class FileParseService(ABC):
             "selfDefineTags",
             "prompt",
             "standardAnswerSql",
+            "llmCode",
             "llmOutput",
             "executeResult",
             "errorMsg",
@@ -124,6 +121,7 @@ class FileParseService(ABC):
                 cm.selfDefineTags,
                 cm.prompt,
                 cm.standardAnswerSql,
+                cm.llmCode,
                 cm.llmOutput,
                 json.dumps(cm.executeResult, ensure_ascii=False)
                 if cm.executeResult is not None
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py
index e9a52603b..a4308edb4 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py
@@ -122,6 +122,7 @@ class RoundAnswerConfirmModel:
     executeResult: Optional[Dict[str, List[str]]] = None
     errorMsg: Optional[str] = None
     compareResult: Optional[DataCompareResultEnum] = None
+    llmCode: Optional[str] = None
 
 
 class FileParseTypeEnum(Enum):
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py
index 16f426fe6..72934e567 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py
@@ -90,7 +90,7 @@ class UserInputExecuteService:
         )
         self._execute_llm_compare_result(
             output_file_path,
-            1,
+            round_id,
             inputs,
             standard_sets,
             right_outputs,
@@ -109,14 +109,21 @@ class UserInputExecuteService:
         config: BenchmarkExecuteConfig,
     ):
         left_map = {a.serialNo: a for a in left_answers}
-        right_map = {a.serialNo: a for a in right_answers}
+        # group right answers by serialNo to support multiple models per input
+        right_group_map: Dict[int, List[AnswerExecuteModel]] = {}
+        for a in right_answers:
+            right_group_map.setdefault(a.serialNo, []).append(a)
 
         confirm_list: List[RoundAnswerConfirmModel] = []
+        # compute unique llm_count across all right answers
+        llm_codes = {a.llm_code for a in right_answers if getattr(a, "llm_code", None)}
+        llm_count = len(llm_codes) if llm_codes else len(right_answers)
+
         for inp in inputs:
             left = left_map.get(inp.serial_no)
-            right = right_map.get(inp.serial_no)
+            rights = right_group_map.get(inp.serial_no, [])
-            if left is None and right is None:
+            if left is None and not rights:
                 continue
 
             strategy_cfg = None
@@ -132,12 +139,11 @@ class UserInputExecuteService:
                 strategy_cfg = DataCompareStrategyConfig(
                     strategy="EXACT_MATCH",
                     order_by=True,
-                    standard_result=standard_result_list
-                    if standard_result_list
-                    else None,
+                    standard_result=standard_result_list if standard_result_list else None,
                 )
 
-            if right is not None:
+            # for each right answer (per model)
+            for right in rights:
                 if config.compare_config and isinstance(config.compare_config, dict):
                     res = self.compare_service.compare_json_by_config(
                         left.llmOutput if left else "",
@@ -144,40 +150,39 @@ class UserInputExecuteService:
                         right.llmOutput if right else "",
                         config.compare_config,
                     )
                     compare_result = res if res else DataCompareResultEnum.FAILED
                 else:
                     res = self.compare_service.compare(
-                        left
-                        if left
-                        else AnswerExecuteModel(
+                        left if left else AnswerExecuteModel(
                             serialNo=inp.serial_no,
                             analysisModelId=inp.analysis_model_id,
                             question=inp.question,
                         ),
                         right,
                         strategy_cfg,
                     )
                     compare_result = res
 
                 confirm = RoundAnswerConfirmModel(
                     serialNo=inp.serial_no,
                     analysisModelId=inp.analysis_model_id,
                     question=inp.question,
                     selfDefineTags=inp.self_define_tags,
                     prompt=inp.prompt,
                     standardAnswerSql=left.llmOutput if left else None,
                     llmOutput=right.llmOutput if right else None,
                     executeResult=right.executeResult if right else None,
                     errorMsg=right.errorMsg if right else None,
                     compareResult=compare_result,
+                    llmCode=right.llm_code,
                 )
                 confirm_list.append(confirm)
 
         self.file_parse_service.write_benchmark_result(
             round_id,
             confirm_list,
             config.benchmark_mode_type == BenchmarkModeTypeEnum.EXECUTE,
-            2,
+            llm_count,
         )
 
     def _convert_query_result_to_column_format(
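
Note on the written layout: the new llmCode column slots between
standardAnswerSql and llmOutput, so each row of the sheet identifies which
model produced it. A minimal sketch of the resulting layout, assuming an
openpyxl-style writer (the actual writer behind FileParseService may differ,
and the two header names between serialNo and selfDefineTags are inferred
from the model fields, not shown in the hunk):

    from openpyxl import Workbook

    headers = [
        "serialNo",
        "analysisModelId",
        "question",
        "selfDefineTags",
        "prompt",
        "standardAnswerSql",
        "llmCode",
        "llmOutput",
        "executeResult",
        "errorMsg",
    ]

    wb = Workbook()
    ws = wb.active
    ws.append(headers)
    # one row per (input, model) pair; executeResult serialized as JSON text
    ws.append([1, "m-001", "q", "", "p", "SELECT 1", "model-a", "SELECT 1", "{}", ""])
    wb.save("benchmark_round1.xlsx")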
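
On the service side, the serialNo grouping and the llm_count derivation added
to _execute_llm_compare_result can be sketched standalone. Answer below is a
hypothetical stand-in for AnswerExecuteModel and the model codes are sample
values, not part of this patch:

    from dataclasses import dataclass
    from typing import Dict, List, Optional

    @dataclass
    class Answer:  # hypothetical stand-in for AnswerExecuteModel
        serialNo: int
        llm_code: Optional[str]
        llmOutput: str

    def group_by_serial(rights: List[Answer]) -> Dict[int, List[Answer]]:
        # mirrors right_group_map: one bucket per serialNo, one entry per model
        grouped: Dict[int, List[Answer]] = {}
        for a in rights:
            grouped.setdefault(a.serialNo, []).append(a)
        return grouped

    def count_models(rights: List[Answer]) -> int:
        # mirrors llm_count: distinct llm_code values, else the raw answer count
        codes = {a.llm_code for a in rights if getattr(a, "llm_code", None)}
        return len(codes) if codes else len(rights)

    rights = [
        Answer(1, "model-a", "SELECT 1"),
        Answer(1, "model-b", "SELECT 1"),
        Answer(2, "model-a", "SELECT 2"),
    ]
    assert group_by_serial(rights)[1][1].llm_code == "model-b"
    assert count_models(rights) == 2  # two distinct models

This is also why the trailing hardcoded 2 in the write_benchmark_result call
becomes llm_count: the written layout now depends on how many models actually
answered, not on an assumed pair.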
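
Likewise, the per-model fan-out (for right in rights) means each input row can
emit several confirm rows, each tagged with its llmCode. A self-contained
sketch under the same assumptions, with a plain string comparison standing in
for the real compare_service and simplified names (Right, ConfirmRow, fan_out)
that are not the service's API:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class Right:  # hypothetical per-model answer
        serialNo: int
        llm_code: str
        llmOutput: str

    @dataclass
    class ConfirmRow:  # simplified stand-in for RoundAnswerConfirmModel
        serialNo: int
        llmCode: Optional[str]
        compareResult: str

    def fan_out(
        serial_no: int, standard_sql: str, rights: List[Right]
    ) -> List[ConfirmRow]:
        # one confirm row per model answer, as in the patched loop
        rows: List[ConfirmRow] = []
        for right in rights:
            verdict = "SUCCESS" if right.llmOutput == standard_sql else "FAILED"
            rows.append(ConfirmRow(serial_no, right.llm_code, verdict))
        return rows

    rows = fan_out(
        1,
        "SELECT 1",
        [Right(1, "model-a", "SELECT 1"), Right(1, "model-b", "SELECT 2")],
    )
    assert [(r.llmCode, r.compareResult) for r in rows] == [
        ("model-a", "SUCCESS"),
        ("model-b", "FAILED"),
    ]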