feat: benchmark post_dispatch service

yaoyifan-yyf
2025-09-26 17:47:11 +08:00
parent bc14084826
commit 4de41d1a25
18 changed files with 590 additions and 2 deletions

View File

@@ -0,0 +1,4 @@
{"serialNo":1,"analysisModelId":"D2025050900161503000025249569","question":"各性别的平均年龄是多少,并按年龄顺序显示结果?","selfDefineTags":"KAGGLE_DS_1,CTE1","prompt":"...","knowledge":""}
{"serialNo":2,"analysisModelId":"D2025050900161503000025249569","question":"不同投资目标下政府债券的总量是多少,并按目标名称排序?","selfDefineTags":"KAGGLE_DS_1,CTE1","prompt":"...","knowledge":""}
{"serialNo":3,"analysisModelId":"D2025050900161503000025249569","question":"用于触发双模型结果数不相等的case","selfDefineTags":"TEST","prompt":"...","knowledge":""}
{"serialNo":4,"analysisModelId":"D2025050900161503000025249569","question":"用于JSON对比策略的case","selfDefineTags":"TEST_JSON","prompt":"...","knowledge":""}

View File

@@ -0,0 +1,5 @@
{"serialNo":1,"analysisModelId":"D2025050900161503000025249569","question":"各性别的平均年龄是多少,并按年龄顺序显示结果?","llmOutput":"with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;","executeResult":{"性别":["Female","Male"],"平均年龄":["27.73","27.84"]},"errorMsg":null}
{"serialNo":2,"analysisModelId":"D2025050900161503000025249569","question":"不同投资目标下政府债券的总量是多少,并按目标名称排序?","llmOutput":"with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;","executeResult":{"objective":["Capital Appreciation","Growth","Income"],"政府债券总量":["117","54","15"]},"errorMsg":null}
{"serialNo":3,"analysisModelId":"D2025050900161503000025249569","question":"用于触发双模型结果数不相等的case","llmOutput":"select 1","executeResult":{"colA":["x","y"]},"errorMsg":null}
{"serialNo":4,"analysisModelId":"D2025050900161503000025249569","question":"用于JSON对比策略的case","llmOutput":"{\"check\":\"ok\"}","executeResult":null,"errorMsg":null}
{"serialNo":5,"analysisModelId":"D2025050900161503000025249569","question":"缺少匹配标准的case","llmOutput":"select * from t","executeResult":null,"errorMsg":"execution error"}

View File

@@ -0,0 +1,4 @@
{"serialNo": 1, "analysisModelId": "D2025050900161503000025249569", "question": "各性别的平均年龄是多少,并按年龄顺序显示结果?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with converted_data as (\n select \n gender,\n cast(age as int) as age\n from \n ant_icube_dev.di_finance_data\n where \n age rlike '^[0-9]+$'\n)\nselect\n gender as `性别`,\n avg(age) as `平均年龄`\nfrom \n converted_data\ngroup by \n gender\norder by \n `平均年龄`;", "llmOutput": "with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;", "executeResult": {"性别": ["Female", "Male"], "平均年龄": ["27.73", "27.84"]}, "errorMsg": null, "compareResult": "RIGHT", "isExecute": true, "llmCount": 2}
{"serialNo": 2, "analysisModelId": "D2025050900161503000025249569", "question": "不同投资目标下政府债券的总量是多少,并按目标名称排序?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with gov_bonds_data as (\n select\n objective,\n cast(government_bonds as bigint) as gov_bond_value\n from\n ant_icube_dev.di_finance_data\n where\n government_bonds is not null\n and government_bonds rlike '^[0-9]+$'\n)\nselect\n objective as `objective`,\n sum(gov_bond_value) as `政府债券总量`\nfrom\n gov_bonds_data\ngroup by\n `objective`\norder by\n `objective`;", "llmOutput": "with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;", "executeResult": {"objective": ["Capital Appreciation", "Growth", "Income"], "政府债券总量": ["117", "54", "15"]}, "errorMsg": null, "compareResult": "RIGHT", "isExecute": true, "llmCount": 2}
{"serialNo": 3, "analysisModelId": "D2025050900161503000025249569", "question": "用于触发双模型结果数不相等的case", "selfDefineTags": "TEST", "prompt": "...", "standardAnswerSql": null, "llmOutput": "select 1", "executeResult": {"colA": ["x", "y"]}, "errorMsg": null, "compareResult": "FAILED", "isExecute": true, "llmCount": 2}
{"serialNo": 4, "analysisModelId": "D2025050900161503000025249569", "question": "用于JSON对比策略的case", "selfDefineTags": "TEST_JSON", "prompt": "...", "standardAnswerSql": null, "llmOutput": "{\"check\":\"ok\"}", "executeResult": null, "errorMsg": null, "compareResult": "FAILED", "isExecute": true, "llmCount": 2}

View File

@@ -0,0 +1,6 @@
{
  "right": 2,
  "wrong": 0,
  "failed": 2,
  "exception": 0
}

View File

@@ -0,0 +1,4 @@
{"serialNo":1,"analysisModelId":"D2025050900161503000025249569","question":"各性别的平均年龄是多少,并按年龄顺序显示结果?","llmOutput":"with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;","executeResult":{"性别":["Female","Male"],"平均年龄":["27.73","27.84"]},"errorMsg":null}
{"serialNo":2,"analysisModelId":"D2025050900161503000025249569","question":"不同投资目标下政府债券的总量是多少,并按目标名称排序?","llmOutput":"with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;","executeResult":{"objective":["Capital Appreciation","Growth","Income"],"政府债券总量":["117","54","15"]},"errorMsg":null}
{"serialNo":3,"analysisModelId":"D2025050900161503000025249569","question":"用于触发双模型结果数不相等的case","llmOutput":"select 1","executeResult":{"colA":["x","y"]},"errorMsg":null}
{"serialNo":4,"analysisModelId":"D2025050900161503000025249569","question":"用于JSON对比策略的case","llmOutput":"{\"check\":\"ok\"}","executeResult":null,"errorMsg":null}

View File

@@ -0,0 +1,4 @@
{"serialNo":1,"analysisModelId":"D2025050900161503000025249569","question":"各性别的平均年龄是多少,并按年龄顺序显示结果?","llmOutput":"with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;","executeResult":{"性别":["Female","Male"],"平均年龄":["27.73","27.84"]},"errorMsg":null}
{"serialNo":2,"analysisModelId":"D2025050900161503000025249569","question":"不同投资目标下政府债券的总量是多少,并按目标名称排序?","llmOutput":"with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;","executeResult":{"objective":["Capital Appreciation","Growth","Income"],"政府债券总量":["117","54","15"]},"errorMsg":null}
{"serialNo":3,"analysisModelId":"D2025050900161503000025249569","question":"用于触发双模型结果数不相等的case","llmOutput":"select 1","executeResult":{"colB":["x","z","w"]},"errorMsg":null}
{"serialNo":4,"analysisModelId":"D2025050900161503000025249569","question":"用于JSON对比策略的case","llmOutput":"{\"check\":\"ok\"}","executeResult":null,"errorMsg":null}

View File

@@ -0,0 +1,4 @@
{"serialNo": 1, "analysisModelId": "D2025050900161503000025249569", "question": "各性别的平均年龄是多少,并按年龄顺序显示结果?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;", "llmOutput": "with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;", "executeResult": {"性别": ["Female", "Male"], "平均年龄": ["27.73", "27.84"]}, "errorMsg": null, "compareResult": "EXCEPTION", "isExecute": false, "llmCount": 2}
{"serialNo": 2, "analysisModelId": "D2025050900161503000025249569", "question": "不同投资目标下政府债券的总量是多少,并按目标名称排序?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;", "llmOutput": "with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;", "executeResult": {"objective": ["Capital Appreciation", "Growth", "Income"], "政府债券总量": ["117", "54", "15"]}, "errorMsg": null, "compareResult": "EXCEPTION", "isExecute": false, "llmCount": 2}
{"serialNo": 3, "analysisModelId": "D2025050900161503000025249569", "question": "用于触发双模型结果数不相等的case", "selfDefineTags": "TEST", "prompt": "...", "standardAnswerSql": "select 1", "llmOutput": "select 1", "executeResult": {"colB": ["x", "z", "w"]}, "errorMsg": null, "compareResult": "EXCEPTION", "isExecute": false, "llmCount": 2}
{"serialNo": 4, "analysisModelId": "D2025050900161503000025249569", "question": "用于JSON对比策略的case", "selfDefineTags": "TEST_JSON", "prompt": "...", "standardAnswerSql": "{\"check\":\"ok\"}", "llmOutput": "{\"check\":\"ok\"}", "executeResult": null, "errorMsg": null, "compareResult": "RIGHT", "isExecute": false, "llmCount": 2}

View File

@@ -0,0 +1,6 @@
{
  "right": 1,
  "wrong": 0,
  "failed": 0,
  "exception": 3
}

View File

@@ -0,0 +1,144 @@
from typing import Dict, List, Optional
from models import DataCompareResult, DataCompareResultEnum, DataCompareStrategyConfig, AnswerExecuteModel
from copy import deepcopy
import hashlib
import json
from decimal import Decimal, ROUND_HALF_UP


def md5_list(values: List[str]) -> str:
    """Join a column's values and hash them so columns can be compared cheaply."""
    s = ",".join([v if v is not None else "" for v in values])
    return hashlib.md5(s.encode("utf-8")).hexdigest()


def accurate_decimal(table: Dict[str, List[str]], scale: int = 2) -> Dict[str, List[str]]:
    """Normalize numeric strings to a fixed scale; non-numeric values pass through unchanged."""
    out = {}
    for k, col in table.items():
        new_col = []
        for v in col:
            if v is None:
                new_col.append("")
                continue
            vs = str(v)
            try:
                d = Decimal(vs)
                new_col.append(str(d.quantize(Decimal("1." + "0" * scale), rounding=ROUND_HALF_UP)))
            except Exception:
                new_col.append(vs)
        out[k] = new_col
    return out


def sort_columns_by_key(table: Dict[str, List[str]], sort_key: str) -> Dict[str, List[str]]:
    """Reorder every column of the table by the row order induced by sort_key."""
    if sort_key not in table:
        raise ValueError(f"base col not exist: {sort_key}")
    base = table[sort_key]
    row_count = len(base)
    for k, col in table.items():
        if len(col) != row_count:
            raise ValueError(f"col length diff: {k}")
    indices = list(range(row_count))
    indices.sort(key=lambda i: "" if base[i] is None else str(base[i]))
    sorted_table = {}
    for k in table.keys():
        sorted_table[k] = [table[k][i] for i in indices]
    return sorted_table


class DataCompareService:
    def compare(self, standard_model: AnswerExecuteModel, target_result: Optional[Dict[str, List[str]]]) -> DataCompareResult:
        if target_result is None:
            return DataCompareResult.failed("targetResult is null")
        cfg: DataCompareStrategyConfig = standard_model.strategyConfig or DataCompareStrategyConfig(
            strategy="EXACT_MATCH", order_by=True, standard_result=None
        )
        if not cfg.standard_result:
            return DataCompareResult.failed("leftResult is null")
        # Try every candidate standard result; the first RIGHT wins.
        for std in cfg.standard_result:
            if not isinstance(std, dict):
                continue
            std_fmt = accurate_decimal(deepcopy(std), 2)
            tgt_fmt = accurate_decimal(deepcopy(target_result), 2)
            if cfg.order_by:
                res = self._compare_ordered(std_fmt, cfg, tgt_fmt)
            else:
                res = self._compare_unordered(std_fmt, cfg, tgt_fmt)
            if res.compare_result == DataCompareResultEnum.RIGHT:
                return res
        return DataCompareResult.wrong("compareResult wrong!")

    def _compare_ordered(self, std: Dict[str, List[str]], cfg: DataCompareStrategyConfig, tgt: Dict[str, List[str]]) -> DataCompareResult:
        try:
            std_md5 = set()
            for col_vals in std.values():
                lst = ["" if v is None else str(v) for v in col_vals]
                std_md5.add(md5_list(lst))
            tgt_md5 = set()
            for col_vals in tgt.values():
                lst = ["" if v is None else str(v) for v in col_vals]
                tgt_md5.add(md5_list(lst))
            tgt_size = len(tgt_md5)
            inter = tgt_md5.intersection(std_md5)
            # Every column matches on both sides: exact hit.
            if tgt_size == len(inter) and tgt_size == len(std_md5):
                return DataCompareResult.right("compareResult success!")
            # All standard columns are contained in the target.
            if len(std_md5) == len(inter):
                if cfg.strategy == "EXACT_MATCH":
                    return DataCompareResult.failed("compareResult failed!")
                elif cfg.strategy == "CONTAIN_MATCH":
                    return DataCompareResult.right("compareResult success!")
            return DataCompareResult.wrong("compareResult wrong!")
        except Exception as e:
            return DataCompareResult.exception(f"compareResult Exception! {e}")

    def _compare_unordered(self, std: Dict[str, List[str]], cfg: DataCompareStrategyConfig, tgt: Dict[str, List[str]]) -> DataCompareResult:
        try:
            tgt_md5 = []
            tgt_cols = []
            for k, col_vals in tgt.items():
                lst = ["" if v is None else str(v) for v in col_vals]
                lst.sort()
                tgt_md5.append(md5_list(lst))
                tgt_cols.append(k)
            # For each standard column, find the target column with the same value set,
            # sort both tables by that column and fall back to the ordered comparison.
            for std_key, std_vals in std.items():
                std_list = ["" if v is None else str(v) for v in std_vals]
                std_list.sort()
                std_md5 = md5_list(std_list)
                if std_md5 not in tgt_md5:
                    return DataCompareResult.wrong("compareResult wrong!")
                idx = tgt_md5.index(std_md5)
                tgt_key = tgt_cols[idx]
                std_sorted = sort_columns_by_key(std, std_key)
                tgt_sorted = sort_columns_by_key(tgt, tgt_key)
                ordered_cfg = DataCompareStrategyConfig(
                    strategy=cfg.strategy,
                    order_by=True,
                    standard_result=cfg.standard_result
                )
                res = self._compare_ordered(std_sorted, ordered_cfg, tgt_sorted)
                if res.compare_result == DataCompareResultEnum.RIGHT:
                    return res
            return DataCompareResult.wrong("compareResult wrong!")
        except Exception as e:
            return DataCompareResult.exception(f"compareResult Exception! {e}")

    def compare_json_by_config(self, standard_answer: str, answer: str, compare_config: Dict[str, str]) -> DataCompareResult:
        try:
            if not standard_answer or not answer:
                return DataCompareResult.failed("standardAnswer or answer is null")
            ans = json.loads(answer)
            for k, strat in compare_config.items():
                if k not in ans:
                    return DataCompareResult.wrong("key missing")
                if strat in ("FULL_TEXT", "ARRAY"):
                    # Mock check: the fixtures in this commit use the literal value "ok".
                    if str(ans[k]) != "ok":
                        return DataCompareResult.wrong("value mismatch")
                elif strat == "DAL":
                    return DataCompareResult.failed("DAL compare not supported in mock")
                else:
                    return DataCompareResult.failed(f"unknown strategy {strat}")
            return DataCompareResult.right("json compare success")
        except Exception as e:
            return DataCompareResult.exception(f"compareJsonByConfig Exception! {e}")

View File

@@ -0,0 +1,118 @@
import json
from typing import List
from models import BaseInputModel, AnswerExecuteModel, RoundAnswerConfirmModel, DataCompareResultEnum, DataCompareStrategyConfig
import pandas as pd
import os


class FileParseService:
    def parse_input_sets(self, path: str) -> List[BaseInputModel]:
        """Read a JSONL file of benchmark questions into BaseInputModel objects."""
        data = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                obj = json.loads(line)
                data.append(BaseInputModel(
                    serialNo=obj["serialNo"],
                    analysisModelId=obj["analysisModelId"],
                    question=obj["question"],
                    selfDefineTags=obj.get("selfDefineTags"),
                    prompt=obj.get("prompt"),
                    knowledge=obj.get("knowledge"),
                ))
        return data

    def parse_llm_outputs(self, path: str) -> List[AnswerExecuteModel]:
        """Read a JSONL file of model answers (SQL and execute results) into AnswerExecuteModel objects."""
        data = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                obj = json.loads(line)
                data.append(AnswerExecuteModel.from_dict(obj))
        return data

    def write_data_compare_result(self, path: str, round_id: int, confirm_models: List[RoundAnswerConfirmModel], is_execute: bool, llm_count: int):
        if not path.endswith(".jsonl"):
            raise ValueError(f"output_file_path must end with .jsonl, got {path}")
        out_path = path.replace(".jsonl", f".round{round_id}.compare.jsonl")
        with open(out_path, "w", encoding="utf-8") as f:
            for cm in confirm_models:
                row = dict(
                    serialNo=cm.serialNo,
                    analysisModelId=cm.analysisModelId,
                    question=cm.question,
                    selfDefineTags=cm.selfDefineTags,
                    prompt=cm.prompt,
                    standardAnswerSql=cm.standardAnswerSql,
                    llmOutput=cm.llmOutput,
                    executeResult=cm.executeResult,
                    errorMsg=cm.errorMsg,
                    compareResult=cm.compareResult.value if cm.compareResult else None,
                    isExecute=is_execute,
                    llmCount=llm_count
                )
                f.write(json.dumps(row, ensure_ascii=False) + "\n")
        print(f"[write_data_compare_result] compare written to: {out_path}")

    def summary_and_write_multi_round_benchmark_result(self, output_path: str, round_id: int) -> str:
        """Aggregate a round's compare file into right/wrong/failed/exception counts and write a summary JSON."""
        if not output_path.endswith(".jsonl"):
            raise ValueError(f"output_file_path must end with .jsonl, got {output_path}")
        compare_path = output_path.replace(".jsonl", f".round{round_id}.compare.jsonl")
        right, wrong, failed, exception = 0, 0, 0, 0
        if os.path.exists(compare_path):
            with open(compare_path, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue
                    obj = json.loads(line)
                    cr = obj.get("compareResult")
                    if cr == DataCompareResultEnum.RIGHT.value:
                        right += 1
                    elif cr == DataCompareResultEnum.WRONG.value:
                        wrong += 1
                    elif cr == DataCompareResultEnum.FAILED.value:
                        failed += 1
                    elif cr == DataCompareResultEnum.EXCEPTION.value:
                        exception += 1
        else:
            print(f"[summary] compare file not found: {compare_path}")
        summary_path = output_path.replace(".jsonl", f".round{round_id}.summary.json")
        result = dict(right=right, wrong=wrong, failed=failed, exception=exception)
        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
        print(f"[summary] summary written to: {summary_path} -> {result}")
        return json.dumps(result, ensure_ascii=False)

    def parse_standard_benchmark_sets(self, standard_excel_path: str) -> List[AnswerExecuteModel]:
        """Read the standard-answer Excel sheet and build AnswerExecuteModel entries with a CONTAIN_MATCH strategy."""
        df = pd.read_excel(standard_excel_path, sheet_name="Sheet1")
        outputs: List[AnswerExecuteModel] = []
        for _, row in df.iterrows():
            try:
                serial_no = int(row["编号"])
            except Exception:
                continue
            question = row.get("用户问题")
            analysis_model_id = row.get("数据集ID")
            llm_output = None if pd.isna(row.get("标准答案SQL")) else str(row.get("标准答案SQL"))
            order_by = True
            if not pd.isna(row.get("是否排序")):
                try:
                    order_by = bool(int(row.get("是否排序")))
                except Exception:
                    order_by = True
            std_result = None
            if not pd.isna(row.get("标准结果")):
                try:
                    std_result = json.loads(row.get("标准结果"))
                except Exception:
                    std_result = None
            strategy_config = DataCompareStrategyConfig(
                strategy="CONTAIN_MATCH",
                order_by=order_by,
                standard_result=[std_result] if std_result is not None else None  # wrap the single result in a list
            )
            outputs.append(AnswerExecuteModel(
                serialNo=serial_no,
                analysisModelId=analysis_model_id,
                question=question,
                llmOutput=llm_output,
                executeResult=std_result,
                strategyConfig=strategy_config
            ))
        return outputs
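The round-specific file names are derived purely from the .jsonl output path. A small sketch of the naming convention implemented by write_data_compare_result and summary_and_write_multi_round_benchmark_result, using a path from the demo script later in this commit:

# Sketch of the derived file names for round 1.
output_file_path = "data/output_round1_modelB.jsonl"
round_id = 1

compare_path = output_file_path.replace(".jsonl", f".round{round_id}.compare.jsonl")
summary_path = output_file_path.replace(".jsonl", f".round{round_id}.summary.json")

print(compare_path)  # data/output_round1_modelB.round1.compare.jsonl
print(summary_path)  # data/output_round1_modelB.round1.summary.json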

View File

@@ -0,0 +1,116 @@
# app/services/models.py
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional


class BenchmarkModeTypeEnum(str, Enum):
    BUILD = "BUILD"
    EXECUTE = "EXECUTE"


@dataclass
class DataCompareStrategyConfig:
    strategy: str  # "EXACT_MATCH" | "CONTAIN_MATCH"
    order_by: bool = True
    standard_result: Optional[List[Dict[str, List[str]]]] = None  # a list of candidate result tables


class DataCompareResultEnum(str, Enum):
    RIGHT = "RIGHT"
    WRONG = "WRONG"
    FAILED = "FAILED"
    EXCEPTION = "EXCEPTION"


@dataclass
class DataCompareResult:
    compare_result: DataCompareResultEnum
    msg: str = ""

    @staticmethod
    def right(msg=""): return DataCompareResult(DataCompareResultEnum.RIGHT, msg)

    @staticmethod
    def wrong(msg=""): return DataCompareResult(DataCompareResultEnum.WRONG, msg)

    @staticmethod
    def failed(msg=""): return DataCompareResult(DataCompareResultEnum.FAILED, msg)

    @staticmethod
    def exception(msg=""): return DataCompareResult(DataCompareResultEnum.EXCEPTION, msg)


@dataclass
class BaseInputModel:
    serialNo: int
    analysisModelId: str
    question: str
    selfDefineTags: Optional[str] = None
    prompt: Optional[str] = None
    knowledge: Optional[str] = None


@dataclass
class AnswerExecuteModel:
    serialNo: int
    analysisModelId: str
    question: str
    llmOutput: Optional[str]
    executeResult: Optional[Dict[str, List[str]]]
    errorMsg: Optional[str] = None
    strategyConfig: Optional[DataCompareStrategyConfig] = None
    cotTokens: Optional[Any] = None

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "AnswerExecuteModel":
        cfg = d.get("strategyConfig")
        strategy_config = None
        if cfg:
            std_list = cfg.get("standard_result")
            strategy_config = DataCompareStrategyConfig(
                strategy=cfg.get("strategy"),
                order_by=cfg.get("order_by", True),
                standard_result=std_list if isinstance(std_list, list) else None
            )
        return AnswerExecuteModel(
            serialNo=d["serialNo"],
            analysisModelId=d["analysisModelId"],
            question=d["question"],
            llmOutput=d.get("llmOutput"),
            executeResult=d.get("executeResult"),
            errorMsg=d.get("errorMsg"),
            strategyConfig=strategy_config,
            cotTokens=d.get("cotTokens"),
        )

    def to_dict(self) -> Dict[str, Any]:
        cfg = None
        if self.strategyConfig:
            cfg = dict(
                strategy=self.strategyConfig.strategy,
                order_by=self.strategyConfig.order_by,
                standard_result=self.strategyConfig.standard_result
            )
        return dict(
            serialNo=self.serialNo,
            analysisModelId=self.analysisModelId,
            question=self.question,
            llmOutput=self.llmOutput,
            executeResult=self.executeResult,
            errorMsg=self.errorMsg,
            strategyConfig=cfg,
            cotTokens=self.cotTokens
        )


@dataclass
class RoundAnswerConfirmModel:
    serialNo: int
    analysisModelId: str
    question: str
    selfDefineTags: Optional[str]
    prompt: Optional[str]
    standardAnswerSql: Optional[str] = None
    strategyConfig: Optional[DataCompareStrategyConfig] = None
    llmOutput: Optional[str] = None
    executeResult: Optional[Dict[str, List[str]]] = None
    errorMsg: Optional[str] = None
    compareResult: Optional[DataCompareResultEnum] = None


@dataclass
class BenchmarkExecuteConfig:
    benchmarkModeType: BenchmarkModeTypeEnum
    compareResultEnable: bool
    standardFilePath: Optional[str] = None
    compareConfig: Optional[Dict[str, str]] = None
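As a quick check of the serialization helpers above, a minimal round-trip sketch; the row mirrors one of the JSONL answer fixtures in this commit:

# Sketch: parse one answer row and serialize it back.
from models import AnswerExecuteModel

row = {
    "serialNo": 3,
    "analysisModelId": "D2025050900161503000025249569",
    "question": "用于触发双模型结果数不相等的case",
    "llmOutput": "select 1",
    "executeResult": {"colA": ["x", "y"]},
    "errorMsg": None,
}
model = AnswerExecuteModel.from_dict(row)
assert model.strategyConfig is None                          # no strategyConfig key in the row
assert model.to_dict()["executeResult"] == {"colA": ["x", "y"]}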

View File

@@ -0,0 +1,65 @@
from file_parse_service import FileParseService
from data_compare_service import DataCompareService
from user_input_execute_service import UserInputExecuteService
from models import BenchmarkExecuteConfig, BenchmarkModeTypeEnum


def run_build_mode():
    """BUILD mode: compare two models' outputs for the same round against each other."""
    fps = FileParseService()
    dcs = DataCompareService()
    svc = UserInputExecuteService(fps, dcs)
    inputs = fps.parse_input_sets("data/input_round1.jsonl")
    left = fps.parse_llm_outputs("data/output_round1_modelA.jsonl")
    right = fps.parse_llm_outputs("data/output_round1_modelB.jsonl")
    config = BenchmarkExecuteConfig(
        benchmarkModeType=BenchmarkModeTypeEnum.BUILD,
        compareResultEnable=True,
        standardFilePath=None,
        compareConfig={"check": "FULL_TEXT"}
    )
    svc.post_dispatch(
        round_id=1,
        config=config,
        inputs=inputs,
        left_outputs=left,
        right_outputs=right,
        input_file_path="data/input_round1.jsonl",
        output_file_path="data/output_round1_modelB.jsonl"
    )
    fps.summary_and_write_multi_round_benchmark_result("data/output_round1_modelB.jsonl", 1)
    print("BUILD compare path:", "data/output_round1_modelB.round1.compare.jsonl")


def run_execute_mode():
    """EXECUTE mode: compare one model's outputs against the standard answers from Excel."""
    fps = FileParseService()
    dcs = DataCompareService()
    svc = UserInputExecuteService(fps, dcs)
    inputs = fps.parse_input_sets("data/input_round1.jsonl")
    right = fps.parse_llm_outputs("data/output_execute_model.jsonl")
    config = BenchmarkExecuteConfig(
        benchmarkModeType=BenchmarkModeTypeEnum.EXECUTE,
        compareResultEnable=True,
        standardFilePath="data/standard_answers.xlsx",
        compareConfig=None
    )
    svc.post_dispatch(
        round_id=1,
        config=config,
        inputs=inputs,
        left_outputs=[],
        right_outputs=right,
        input_file_path="data/input_round1.jsonl",
        output_file_path="data/output_execute_model.jsonl"
    )
    fps.summary_and_write_multi_round_benchmark_result("data/output_execute_model.jsonl", 1)
    print("EXECUTE compare path:", "data/output_execute_model.round1.compare.jsonl")


if __name__ == "__main__":
    run_build_mode()
    run_execute_mode()

View File

@@ -0,0 +1,108 @@
# app/services/user_input_execute_service.py
from typing import List
from models import (
    BaseInputModel, AnswerExecuteModel, RoundAnswerConfirmModel,
    BenchmarkExecuteConfig, BenchmarkModeTypeEnum, DataCompareResultEnum, DataCompareStrategyConfig
)
from file_parse_service import FileParseService
from data_compare_service import DataCompareService


class UserInputExecuteService:
    def __init__(self, file_service: FileParseService, compare_service: DataCompareService):
        self.file_service = file_service
        self.compare_service = compare_service

    def post_dispatch(
        self,
        round_id: int,
        config: BenchmarkExecuteConfig,
        inputs: List[BaseInputModel],
        left_outputs: List[AnswerExecuteModel],
        right_outputs: List[AnswerExecuteModel],
        input_file_path: str,
        output_file_path: str
    ):
        try:
            if config.benchmarkModeType == BenchmarkModeTypeEnum.BUILD and config.compareResultEnable:
                # BUILD mode: both models' outputs are required before they can be compared.
                if left_outputs and right_outputs:
                    self._execute_llm_compare_result(output_file_path, round_id, inputs, left_outputs, right_outputs, config)
            elif config.benchmarkModeType == BenchmarkModeTypeEnum.EXECUTE and config.compareResultEnable:
                # EXECUTE mode: the left side is replaced by the standard answers parsed from Excel.
                if config.standardFilePath and right_outputs:
                    standard_sets = self.file_service.parse_standard_benchmark_sets(config.standardFilePath)
                    self._execute_llm_compare_result(output_file_path, 1, inputs, standard_sets, right_outputs, config)
        except Exception as e:
            print(f"[post_dispatch] compare error: {e}")

    def _execute_llm_compare_result(
        self,
        location: str,
        round_id: int,
        inputs: List[BaseInputModel],
        left_answers: List[AnswerExecuteModel],
        right_answers: List[AnswerExecuteModel],
        config: BenchmarkExecuteConfig
    ):
        left_map = {a.serialNo: a for a in left_answers}
        right_map = {a.serialNo: a for a in right_answers}
        confirm_list: List[RoundAnswerConfirmModel] = []
        for inp in inputs:
            left = left_map.get(inp.serialNo)
            right = right_map.get(inp.serialNo)
            if left is None and right is None:
                continue
            strategy_cfg = None
            standard_sql = None
            compare_result = None
            if left is not None:
                standard_sql = left.llmOutput
                if config.benchmarkModeType == BenchmarkModeTypeEnum.EXECUTE:
                    strategy_cfg = left.strategyConfig
                else:
                    # BUILD mode: treat the left model's execute result as the standard result.
                    standard_result_list = []
                    if left.executeResult:
                        standard_result_list.append(left.executeResult)
                    strategy_cfg = DataCompareStrategyConfig(
                        strategy="EXACT_MATCH",
                        order_by=True,
                        standard_result=standard_result_list if standard_result_list else None
                    )
                    # Attach the derived config so DataCompareService.compare can pick it up.
                    if left.strategyConfig is None:
                        left.strategyConfig = strategy_cfg
            if right is not None:
                if config.compareConfig and isinstance(config.compareConfig, dict):
                    res = self.compare_service.compare_json_by_config(
                        left.llmOutput if left else "", right.llmOutput or "", config.compareConfig
                    )
                    compare_result = res.compare_result
                else:
                    if strategy_cfg is None:
                        compare_result = DataCompareResultEnum.FAILED
                    else:
                        res = self.compare_service.compare(
                            left if left else AnswerExecuteModel(
                                serialNo=inp.serialNo,
                                analysisModelId=inp.analysisModelId,
                                question=inp.question,
                                llmOutput=None,
                                executeResult=None
                            ),
                            right.executeResult
                        )
                        compare_result = res.compare_result
            confirm = RoundAnswerConfirmModel(
                serialNo=inp.serialNo,
                analysisModelId=inp.analysisModelId,
                question=inp.question,
                selfDefineTags=inp.selfDefineTags,
                prompt=inp.prompt,
                standardAnswerSql=standard_sql,
                strategyConfig=strategy_cfg,
                llmOutput=right.llmOutput if right else None,
                executeResult=right.executeResult if right else None,
                errorMsg=right.errorMsg if right else None,
                compareResult=compare_result
            )
            confirm_list.append(confirm)
        self.file_service.write_data_compare_result(location, round_id, confirm_list, config.benchmarkModeType == BenchmarkModeTypeEnum.EXECUTE, 2)
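One consequence of the dispatch above: in EXECUTE mode, a case whose standard set carries no standard result reaches DataCompareService.compare with standard_result=None and comes back FAILED, which is consistent with the FAILED rows in the EXECUTE compare fixture earlier in this commit. A minimal sketch (identifiers below are illustrative):

# Sketch: EXECUTE-mode rows without a standard result end up as FAILED.
from models import AnswerExecuteModel, DataCompareStrategyConfig
from data_compare_service import DataCompareService

std_without_result = AnswerExecuteModel(
    serialNo=3, analysisModelId="demo", question="demo",
    llmOutput="select 1", executeResult=None,
    strategyConfig=DataCompareStrategyConfig(
        strategy="CONTAIN_MATCH", order_by=True, standard_result=None
    ),
)
res = DataCompareService().compare(std_without_result, {"colA": ["x", "y"]})
print(res.compare_result, res.msg)  # FAILED, "leftResult is null"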

View File

@@ -26,8 +26,8 @@ class BenchmarkDataConfig(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
     cache_dir: str = "cache"
-    db_path: str = "benchmark_data.db"
-    table_mapping_file: Optional[str] = None
+    db_path: str = "pilot/benchmark_meta_data/benchmark_data.db"
+    table_mapping_file: str = "pilot/benchmark_meta_data/table_mapping.json"
     cache_expiry_days: int = 1
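For context, a minimal sketch of the updated defaults in isolation; this assumes pydantic v2 (suggested by the ConfigDict usage) and shows only the fields visible in the hunk above:

# Sketch only: other fields of BenchmarkDataConfig are omitted here.
from pydantic import BaseModel, ConfigDict

class BenchmarkDataConfig(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    cache_dir: str = "cache"
    db_path: str = "pilot/benchmark_meta_data/benchmark_data.db"
    table_mapping_file: str = "pilot/benchmark_meta_data/table_mapping.json"
    cache_expiry_days: int = 1

default_cfg = BenchmarkDataConfig()                        # picks up the new repo-relative defaults
custom_cfg = BenchmarkDataConfig(db_path="/tmp/bench.db")  # defaults can still be overridden per run
print(default_cfg.table_mapping_file, custom_cfg.db_path)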