Mirror of https://github.com/csunny/DB-GPT.git
opt: multi model compare write result
@@ -75,11 +75,7 @@ class FileParseService(ABC):
        output_dir.mkdir(parents=True, exist_ok=True)

-       # Determine final excel file path: <base>_round{round_id}.xlsx
-       base_name = Path(path).stem
-       extension = Path(path).suffix
-       if extension.lower() not in [".xlsx", ".xls"]:
-           extension = ".xlsx"
-       output_file = output_dir / f"{base_name}_round{round_id}{extension}"
+       output_file = path

        headers = [
            "serialNo",
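The writer previously derived the per-round workbook name itself; with `output_file = path` it writes to whatever path the caller hands in, so the `<base>_round{round_id}.xlsx` naming has to happen upstream. A minimal sketch of that naming, mirroring the removed logic (the helper name and call site are hypothetical):

from pathlib import Path

def round_output_path(path: str, output_dir: Path, round_id: int) -> Path:
    # hypothetical helper reproducing the removed naming: <base>_round{round_id}.xlsx
    base_name = Path(path).stem
    extension = Path(path).suffix
    if extension.lower() not in [".xlsx", ".xls"]:
        extension = ".xlsx"
    return output_dir / f"{base_name}_round{round_id}{extension}"

# e.g. round_output_path("result.csv", Path("out"), 3) -> out/result_round3.xlsx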
@@ -88,6 +84,7 @@ class FileParseService(ABC):
            "selfDefineTags",
            "prompt",
            "standardAnswerSql",
+           "llmCode",
            "llmOutput",
            "executeResult",
            "errorMsg",
@@ -124,6 +121,7 @@ class FileParseService(ABC):
                cm.selfDefineTags,
                cm.prompt,
                cm.standardAnswerSql,
+               cm.llmCode,
                cm.llmOutput,
                json.dumps(cm.executeResult, ensure_ascii=False)
                if cm.executeResult is not None
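Both hunks above add the same `llmCode` column, once to the header row and once to the per-record values, so the header order and the row order must stay in sync. A reduced sketch of that pairing (the helper name, the trimmed column set, and the trailing `else` branch are assumptions, not shown in the diff):

import json

HEADERS = ["selfDefineTags", "prompt", "standardAnswerSql", "llmCode",
           "llmOutput", "executeResult", "errorMsg"]

def to_row(cm) -> list:
    # hypothetical helper; values must follow the HEADERS order above
    return [
        cm.selfDefineTags,
        cm.prompt,
        cm.standardAnswerSql,
        cm.llmCode,  # new column: which model produced this answer
        cm.llmOutput,
        json.dumps(cm.executeResult, ensure_ascii=False)
        if cm.executeResult is not None
        else None,
        cm.errorMsg,
    ]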
@@ -122,6 +122,7 @@ class RoundAnswerConfirmModel:
    executeResult: Optional[Dict[str, List[str]]] = None
    errorMsg: Optional[str] = None
    compareResult: Optional[DataCompareResultEnum] = None
+   llmCode: Optional[str] = None


class FileParseTypeEnum(Enum):
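For context, a self-contained sketch of the confirm model with the new field; only the attributes visible in this hunk are included, and the enum members are an assumption:

from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional

class DataCompareResultEnum(Enum):
    SUCCESS = "SUCCESS"   # assumed member
    FAILED = "FAILED"     # referenced elsewhere in this commit

@dataclass
class RoundAnswerConfirmModel:
    executeResult: Optional[Dict[str, List[str]]] = None
    errorMsg: Optional[str] = None
    compareResult: Optional[DataCompareResultEnum] = None
    llmCode: Optional[str] = None  # new: which LLM produced this row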
@@ -90,7 +90,7 @@ class UserInputExecuteService:
        )
        self._execute_llm_compare_result(
            output_file_path,
            1,
            round_id,
            inputs,
            standard_sets,
            right_outputs,
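With several models compared in one run, the `right_outputs` passed here can carry more than one answer per `serialNo`, distinguished by `llm_code`. A toy illustration with a reduced stand-in model (the field subset and values are hypothetical):

from dataclasses import dataclass
from typing import Optional

@dataclass
class AnswerExecuteModel:  # reduced stand-in, illustration only
    serialNo: int
    llmOutput: Optional[str] = None
    llm_code: Optional[str] = None

right_outputs = [
    AnswerExecuteModel(serialNo=1, llmOutput="SELECT 1", llm_code="model_a"),
    AnswerExecuteModel(serialNo=1, llmOutput="SELECT 1", llm_code="model_b"),
    AnswerExecuteModel(serialNo=2, llmOutput="SELECT 2", llm_code="model_a"),
]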
@@ -109,14 +109,21 @@ class UserInputExecuteService:
        config: BenchmarkExecuteConfig,
    ):
        left_map = {a.serialNo: a for a in left_answers}
-       right_map = {a.serialNo: a for a in right_answers}
+       # group right answers by serialNo to support multiple models per input
+       right_group_map: Dict[int, List[AnswerExecuteModel]] = {}
+       for a in right_answers:
+           right_group_map.setdefault(a.serialNo, []).append(a)
        confirm_list: List[RoundAnswerConfirmModel] = []

+       # compute unique llm_count across all right answers
+       llm_codes = set([a.llm_code for a in right_answers if getattr(a, "llm_code", None)])
+       llm_count = len(llm_codes) if llm_codes else len(right_answers)
+
        for inp in inputs:
            left = left_map.get(inp.serial_no)
-           right = right_map.get(inp.serial_no)
+           rights = right_group_map.get(inp.serial_no, [])

-           if left is None and right is None:
+           if left is None and not rights:
                continue

            strategy_cfg = None
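A runnable illustration of the new grouping and model counting, using toy data; the behaviour shown (one bucket per serialNo, llm_count equal to the number of distinct llm_code values) is exactly what the loop above relies on:

from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class Answer:  # stand-in carrying only the fields this logic touches
    serialNo: int
    llm_code: Optional[str] = None

right_answers = [Answer(1, "model_a"), Answer(1, "model_b"), Answer(2, "model_a")]

right_group_map: Dict[int, List[Answer]] = {}
for a in right_answers:
    right_group_map.setdefault(a.serialNo, []).append(a)

llm_codes = {a.llm_code for a in right_answers if a.llm_code}
llm_count = len(llm_codes) if llm_codes else len(right_answers)

print({k: [x.llm_code for x in v] for k, v in right_group_map.items()})
# {1: ['model_a', 'model_b'], 2: ['model_a']}
print(llm_count)  # 2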
@@ -132,12 +139,11 @@ class UserInputExecuteService:
            strategy_cfg = DataCompareStrategyConfig(
                strategy="EXACT_MATCH",
                order_by=True,
-               standard_result=standard_result_list
-               if standard_result_list
-               else None,
+               standard_result=standard_result_list if standard_result_list else None,
            )

-           if right is not None:
+           # for each right answer (per model)
+           for right in rights:
                if config.compare_config and isinstance(config.compare_config, dict):
                    res = self.compare_service.compare_json_by_config(
                        left.llmOutput if left else "",
@@ -150,9 +156,7 @@ class UserInputExecuteService:
                    compare_result = DataCompareResultEnum.FAILED
                else:
                    res = self.compare_service.compare(
-                       left
-                       if left
-                       else AnswerExecuteModel(
+                       left if left else AnswerExecuteModel(
                            serialNo=inp.serial_no,
                            analysisModelId=inp.analysis_model_id,
                            question=inp.question,
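When no standard (left) answer exists for an input, the compare call still needs both sides, so a placeholder `AnswerExecuteModel` is built from the input row. A reduced sketch of that fallback (fields beyond the ones shown in the hunk are omitted, and the helper name is hypothetical):

from dataclasses import dataclass
from typing import Optional

@dataclass
class AnswerExecuteModel:  # reduced stand-in, illustration only
    serialNo: int
    analysisModelId: Optional[int] = None
    question: Optional[str] = None

def left_or_placeholder(left, inp):
    # mirrors `left if left else AnswerExecuteModel(...)` above
    return left if left else AnswerExecuteModel(
        serialNo=inp.serial_no,
        analysisModelId=inp.analysis_model_id,
        question=inp.question,
    )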
@@ -174,6 +178,7 @@ class UserInputExecuteService:
                    executeResult=right.executeResult if right else None,
                    errorMsg=right.errorMsg if right else None,
                    compareResult=compare_result,
+                   llmCode=right.llm_code,
                )
                confirm_list.append(confirm)
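Tagging each confirm row with `right.llm_code` is what lets rows from different models coexist for the same input: together with the per-model loop, every (input, model) pair yields one `RoundAnswerConfirmModel`. A compact sketch of that shape (the compare call and the confirm constructor are stubbed as passed-in callables):

def build_confirms(inp, left, rights, compare_one, make_confirm):
    # compare_one and make_confirm stand in for the service call and model
    # constructor used above; one confirm row per right answer (per model)
    confirms = []
    for right in rights:
        result = compare_one(left, right)
        confirms.append(make_confirm(
            serialNo=inp.serial_no,
            compareResult=result,
            llmCode=right.llm_code,
        ))
    return confirms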
@@ -182,7 +187,7 @@ class UserInputExecuteService:
            round_id,
            confirm_list,
            config.benchmark_mode_type == BenchmarkModeTypeEnum.EXECUTE,
-           2,
+           llm_count,
        )

    def _convert_query_result_to_column_format(
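The fixed `2` previously passed to the result writer becomes the computed `llm_count`, so the number of models reflected in the written result presumably follows what was actually compared rather than assuming a two-way comparison; with the toy data in the grouping sketch above this would be 2, and it grows automatically as additional llm codes appear in `right_answers`.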