Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-26 20:13:40 +00:00)
feat: benchmark post_dispatch service
@@ -0,0 +1,4 @@
{"serialNo":1,"analysisModelId":"D2025050900161503000025249569","question":"各性别的平均年龄是多少,并按年龄顺序显示结果?","selfDefineTags":"KAGGLE_DS_1,CTE1","prompt":"...","knowledge":""}
{"serialNo":2,"analysisModelId":"D2025050900161503000025249569","question":"不同投资目标下政府债券的总量是多少,并按目标名称排序?","selfDefineTags":"KAGGLE_DS_1,CTE1","prompt":"...","knowledge":""}
{"serialNo":3,"analysisModelId":"D2025050900161503000025249569","question":"用于触发双模型结果数不相等的case","selfDefineTags":"TEST","prompt":"...","knowledge":""}
{"serialNo":4,"analysisModelId":"D2025050900161503000025249569","question":"用于JSON对比策略的case","selfDefineTags":"TEST_JSON","prompt":"...","knowledge":""}
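Each record above is one benchmark input: a serial number, a dataset id (analysisModelId), a natural-language question (serial 1, for instance, asks for the average age per gender ordered by age; serials 3 and 4 are synthetic test cases), tags, and optional prompt/knowledge fields. A minimal loading sketch, not part of the commit; the file path is an assumption matching the demo runner later in this commit:

import json

with open("data/input_round1.jsonl", encoding="utf-8") as f:
    records = [json.loads(line) for line in f if line.strip()]
print(sorted(records[0]))
# ['analysisModelId', 'knowledge', 'prompt', 'question', 'selfDefineTags', 'serialNo']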
@@ -0,0 +1,5 @@
{"serialNo":1,"analysisModelId":"D2025050900161503000025249569","question":"各性别的平均年龄是多少,并按年龄顺序显示结果?","llmOutput":"with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;","executeResult":{"性别":["Female","Male"],"平均年龄":["27.73","27.84"]},"errorMsg":null}
{"serialNo":2,"analysisModelId":"D2025050900161503000025249569","question":"不同投资目标下政府债券的总量是多少,并按目标名称排序?","llmOutput":"with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;","executeResult":{"objective":["Capital Appreciation","Growth","Income"],"政府债券总量":["117","54","15"]},"errorMsg":null}
{"serialNo":3,"analysisModelId":"D2025050900161503000025249569","question":"用于触发双模型结果数不相等的case","llmOutput":"select 1","executeResult":{"colA":["x","y"]},"errorMsg":null}
{"serialNo":4,"analysisModelId":"D2025050900161503000025249569","question":"用于JSON对比策略的case","llmOutput":"{\"check\":\"ok\"}","executeResult":null,"errorMsg":null}
{"serialNo":5,"analysisModelId":"D2025050900161503000025249569","question":"缺少匹配标准的case","llmOutput":"select * from t","executeResult":null,"errorMsg":"execution error"}
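Note that serial 5 ("缺少匹配标准的case", a case with no matching standard) has no counterpart in the four-record input set above; the dispatch loop in UserInputExecuteService iterates over the inputs, so this record is silently dropped from the compare output.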
@@ -0,0 +1,4 @@
{"serialNo": 1, "analysisModelId": "D2025050900161503000025249569", "question": "各性别的平均年龄是多少,并按年龄顺序显示结果?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with converted_data as (\n select \n gender,\n cast(age as int) as age\n from \n ant_icube_dev.di_finance_data\n where \n age rlike '^[0-9]+$'\n)\nselect\n gender as `性别`,\n avg(age) as `平均年龄`\nfrom \n converted_data\ngroup by \n gender\norder by \n `平均年龄`;", "llmOutput": "with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;", "executeResult": {"性别": ["Female", "Male"], "平均年龄": ["27.73", "27.84"]}, "errorMsg": null, "compareResult": "RIGHT", "isExecute": true, "llmCount": 2}
{"serialNo": 2, "analysisModelId": "D2025050900161503000025249569", "question": "不同投资目标下政府债券的总量是多少,并按目标名称排序?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with gov_bonds_data as (\n select\n objective,\n cast(government_bonds as bigint) as gov_bond_value\n from\n ant_icube_dev.di_finance_data\n where\n government_bonds is not null\n and government_bonds rlike '^[0-9]+$'\n)\nselect\n objective as `objective`,\n sum(gov_bond_value) as `政府债券总量`\nfrom\n gov_bonds_data\ngroup by\n `objective`\norder by\n `objective`;", "llmOutput": "with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;", "executeResult": {"objective": ["Capital Appreciation", "Growth", "Income"], "政府债券总量": ["117", "54", "15"]}, "errorMsg": null, "compareResult": "RIGHT", "isExecute": true, "llmCount": 2}
{"serialNo": 3, "analysisModelId": "D2025050900161503000025249569", "question": "用于触发双模型结果数不相等的case", "selfDefineTags": "TEST", "prompt": "...", "standardAnswerSql": null, "llmOutput": "select 1", "executeResult": {"colA": ["x", "y"]}, "errorMsg": null, "compareResult": "FAILED", "isExecute": true, "llmCount": 2}
{"serialNo": 4, "analysisModelId": "D2025050900161503000025249569", "question": "用于JSON对比策略的case", "selfDefineTags": "TEST_JSON", "prompt": "...", "standardAnswerSql": null, "llmOutput": "{\"check\":\"ok\"}", "executeResult": null, "errorMsg": null, "compareResult": "FAILED", "isExecute": true, "llmCount": 2}
@@ -0,0 +1,6 @@
{
  "right": 2,
  "wrong": 0,
  "failed": 2,
  "exception": 0
}
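The four counters map one-to-one onto the compareResult values in the compare file above: serials 1 and 2 are RIGHT, serials 3 and 4 FAILED. A tally sketch over that file, not part of the commit; the path follows the naming convention of FileParseService.write_data_compare_result below:

import collections
import json

with open("data/output_round1_modelB.round1.compare.jsonl", encoding="utf-8") as f:
    counts = collections.Counter(json.loads(line)["compareResult"] for line in f if line.strip())
print(counts)  # Counter({'RIGHT': 2, 'FAILED': 2})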
@@ -0,0 +1,4 @@
{"serialNo":1,"analysisModelId":"D2025050900161503000025249569","question":"各性别的平均年龄是多少,并按年龄顺序显示结果?","llmOutput":"with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;","executeResult":{"性别":["Female","Male"],"平均年龄":["27.73","27.84"]},"errorMsg":null}
{"serialNo":2,"analysisModelId":"D2025050900161503000025249569","question":"不同投资目标下政府债券的总量是多少,并按目标名称排序?","llmOutput":"with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;","executeResult":{"objective":["Capital Appreciation","Growth","Income"],"政府债券总量":["117","54","15"]},"errorMsg":null}
{"serialNo":3,"analysisModelId":"D2025050900161503000025249569","question":"用于触发双模型结果数不相等的case","llmOutput":"select 1","executeResult":{"colA":["x","y"]},"errorMsg":null}
{"serialNo":4,"analysisModelId":"D2025050900161503000025249569","question":"用于JSON对比策略的case","llmOutput":"{\"check\":\"ok\"}","executeResult":null,"errorMsg":null}
@@ -0,0 +1,4 @@
{"serialNo":1,"analysisModelId":"D2025050900161503000025249569","question":"各性别的平均年龄是多少,并按年龄顺序显示结果?","llmOutput":"with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;","executeResult":{"性别":["Female","Male"],"平均年龄":["27.73","27.84"]},"errorMsg":null}
{"serialNo":2,"analysisModelId":"D2025050900161503000025249569","question":"不同投资目标下政府债券的总量是多少,并按目标名称排序?","llmOutput":"with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;","executeResult":{"objective":["Capital Appreciation","Growth","Income"],"政府债券总量":["117","54","15"]},"errorMsg":null}
{"serialNo":3,"analysisModelId":"D2025050900161503000025249569","question":"用于触发双模型结果数不相等的case","llmOutput":"select 1","executeResult":{"colB":["x","z","w"]},"errorMsg":null}
{"serialNo":4,"analysisModelId":"D2025050900161503000025249569","question":"用于JSON对比策略的case","llmOutput":"{\"check\":\"ok\"}","executeResult":null,"errorMsg":null}
@@ -0,0 +1,4 @@
{"serialNo": 1, "analysisModelId": "D2025050900161503000025249569", "question": "各性别的平均年龄是多少,并按年龄顺序显示结果?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;", "llmOutput": "with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;", "executeResult": {"性别": ["Female", "Male"], "平均年龄": ["27.73", "27.84"]}, "errorMsg": null, "compareResult": "EXCEPTION", "isExecute": false, "llmCount": 2}
{"serialNo": 2, "analysisModelId": "D2025050900161503000025249569", "question": "不同投资目标下政府债券的总量是多少,并按目标名称排序?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;", "llmOutput": "with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;", "executeResult": {"objective": ["Capital Appreciation", "Growth", "Income"], "政府债券总量": ["117", "54", "15"]}, "errorMsg": null, "compareResult": "EXCEPTION", "isExecute": false, "llmCount": 2}
{"serialNo": 3, "analysisModelId": "D2025050900161503000025249569", "question": "用于触发双模型结果数不相等的case", "selfDefineTags": "TEST", "prompt": "...", "standardAnswerSql": "select 1", "llmOutput": "select 1", "executeResult": {"colB": ["x", "z", "w"]}, "errorMsg": null, "compareResult": "EXCEPTION", "isExecute": false, "llmCount": 2}
{"serialNo": 4, "analysisModelId": "D2025050900161503000025249569", "question": "用于JSON对比策略的case", "selfDefineTags": "TEST_JSON", "prompt": "...", "standardAnswerSql": "{\"check\":\"ok\"}", "llmOutput": "{\"check\":\"ok\"}", "executeResult": null, "errorMsg": null, "compareResult": "RIGHT", "isExecute": false, "llmCount": 2}
@@ -0,0 +1,6 @@
{
  "right": 1,
  "wrong": 0,
  "failed": 0,
  "exception": 3
}
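For contrast with the BUILD summary above: in this EXECUTE-round summary only serial 4, the JSON-strategy case, lands in right; the three SQL cases carry compareResult EXCEPTION in the compare file, which the summary step tallies into the exception bucket.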
Binary file not shown.
@@ -0,0 +1,144 @@
from copy import deepcopy
from decimal import Decimal, ROUND_HALF_UP
from typing import Dict, List, Optional
import hashlib
import json

from models import AnswerExecuteModel, DataCompareResult, DataCompareResultEnum, DataCompareStrategyConfig


def md5_list(values: List[str]) -> str:
    """Fingerprint a column: join its values with commas and hash the result."""
    s = ",".join([v if v is not None else "" for v in values])
    return hashlib.md5(s.encode("utf-8")).hexdigest()


def accurate_decimal(table: Dict[str, List[str]], scale: int = 2) -> Dict[str, List[str]]:
    """Normalize numeric strings to a fixed scale (ROUND_HALF_UP); non-numeric values pass through."""
    out = {}
    for k, col in table.items():
        new_col = []
        for v in col:
            if v is None:
                new_col.append("")
                continue
            vs = str(v)
            try:
                d = Decimal(vs)
                new_col.append(str(d.quantize(Decimal("1." + "0" * scale), rounding=ROUND_HALF_UP)))
            except Exception:
                new_col.append(vs)
        out[k] = new_col
    return out


def sort_columns_by_key(table: Dict[str, List[str]], sort_key: str) -> Dict[str, List[str]]:
    """Reorder every column of the table by the row order induced by sort_key."""
    if sort_key not in table:
        raise ValueError(f"base col not exist: {sort_key}")
    base = table[sort_key]
    row_count = len(base)
    for k, col in table.items():
        if len(col) != row_count:
            raise ValueError(f"col length diff: {k}")
    indices = list(range(row_count))
    indices.sort(key=lambda i: "" if base[i] is None else str(base[i]))
    sorted_table = {}
    for k in table.keys():
        sorted_table[k] = [table[k][i] for i in indices]
    return sorted_table


class DataCompareService:
    def compare(self, standard_model: AnswerExecuteModel, target_result: Optional[Dict[str, List[str]]]) -> DataCompareResult:
        if target_result is None:
            return DataCompareResult.failed("targetResult is null")
        cfg: DataCompareStrategyConfig = standard_model.strategyConfig or DataCompareStrategyConfig(
            strategy="EXACT_MATCH", order_by=True, standard_result=None
        )
        if not cfg.standard_result:
            return DataCompareResult.failed("leftResult is null")

        # The target is RIGHT as soon as it matches any candidate standard result.
        for std in cfg.standard_result:
            if not isinstance(std, dict):
                continue
            std_fmt = accurate_decimal(deepcopy(std), 2)
            tgt_fmt = accurate_decimal(deepcopy(target_result), 2)
            if cfg.order_by:
                res = self._compare_ordered(std_fmt, cfg, tgt_fmt)
            else:
                res = self._compare_unordered(std_fmt, cfg, tgt_fmt)
            if res.compare_result == DataCompareResultEnum.RIGHT:
                return res
        return DataCompareResult.wrong("compareResult wrong!")

    def _compare_ordered(self, std: Dict[str, List[str]], cfg: DataCompareStrategyConfig, tgt: Dict[str, List[str]]) -> DataCompareResult:
        """Compare by column content, ignoring column names: each column is reduced to an md5 fingerprint."""
        try:
            std_md5 = set()
            for col_vals in std.values():
                lst = ["" if v is None else str(v) for v in col_vals]
                std_md5.add(md5_list(lst))

            tgt_md5 = set()
            for col_vals in tgt.values():
                lst = ["" if v is None else str(v) for v in col_vals]
                tgt_md5.add(md5_list(lst))

            tgt_size = len(tgt_md5)
            inter = tgt_md5.intersection(std_md5)

            # Identical fingerprint sets: exact column-for-column match.
            if tgt_size == len(inter) and tgt_size == len(std_md5):
                return DataCompareResult.right("compareResult success!")

            # Every standard column appears in the target, but the target has extras.
            if len(std_md5) == len(inter):
                if cfg.strategy == "EXACT_MATCH":
                    return DataCompareResult.failed("compareResult failed!")
                elif cfg.strategy == "CONTAIN_MATCH":
                    return DataCompareResult.right("compareResult success!")
            return DataCompareResult.wrong("compareResult wrong!")
        except Exception as e:
            return DataCompareResult.exception(f"compareResult Exception! {e}")

    def _compare_unordered(self, std: Dict[str, List[str]], cfg: DataCompareStrategyConfig, tgt: Dict[str, List[str]]) -> DataCompareResult:
        """Row order is irrelevant: find a pivot column shared by both tables, sort both by it, then compare ordered."""
        try:
            tgt_md5 = []
            tgt_cols = []
            for k, col_vals in tgt.items():
                lst = ["" if v is None else str(v) for v in col_vals]
                lst.sort()
                tgt_md5.append(md5_list(lst))
                tgt_cols.append(k)

            for std_key, std_vals in std.items():
                std_list = ["" if v is None else str(v) for v in std_vals]
                std_list.sort()
                std_md5 = md5_list(std_list)
                if std_md5 not in tgt_md5:
                    return DataCompareResult.wrong("compareResult wrong!")

                idx = tgt_md5.index(std_md5)
                tgt_key = tgt_cols[idx]

                std_sorted = sort_columns_by_key(std, std_key)
                tgt_sorted = sort_columns_by_key(tgt, tgt_key)

                ordered_cfg = DataCompareStrategyConfig(
                    strategy=cfg.strategy,
                    order_by=True,
                    standard_result=cfg.standard_result
                )
                res = self._compare_ordered(std_sorted, ordered_cfg, tgt_sorted)
                if res.compare_result == DataCompareResultEnum.RIGHT:
                    return res
            return DataCompareResult.wrong("compareResult wrong!")
        except Exception as e:
            return DataCompareResult.exception(f"compareResult Exception! {e}")

    def compare_json_by_config(self, standard_answer: str, answer: str, compare_config: Dict[str, str]) -> DataCompareResult:
        try:
            if not standard_answer or not answer:
                return DataCompareResult.failed("standardAnswer or answer is null")
            std = json.loads(standard_answer)
            ans = json.loads(answer)
            for k, strat in compare_config.items():
                if k not in ans or k not in std:
                    return DataCompareResult.wrong("key missing")
                if strat in ("FULL_TEXT", "ARRAY"):
                    # Compare the answer's value against the standard answer's value for the same key.
                    if str(ans[k]) != str(std[k]):
                        return DataCompareResult.wrong("value mismatch")
                elif strat == "DAL":
                    return DataCompareResult.failed("DAL compare not supported in mock")
                else:
                    return DataCompareResult.failed(f"unknown strategy {strat}")
            return DataCompareResult.right("json compare success")
        except Exception as e:
            return DataCompareResult.exception(f"compareJsonByConfig Exception! {e}")
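A minimal usage sketch for the unordered path, not part of the commit; the tables mirror the serial-1 fixtures above:

from data_compare_service import DataCompareService
from models import AnswerExecuteModel, DataCompareStrategyConfig

standard = AnswerExecuteModel(
    serialNo=1,
    analysisModelId="D2025050900161503000025249569",
    question="各性别的平均年龄是多少,并按年龄顺序显示结果?",
    llmOutput=None,
    executeResult=None,
    strategyConfig=DataCompareStrategyConfig(
        strategy="CONTAIN_MATCH",
        order_by=False,  # rows may arrive in any order
        standard_result=[{"性别": ["Female", "Male"], "平均年龄": ["27.73", "27.84"]}],
    ),
)
# Same rows as the standard, permuted; the unordered path picks a shared pivot
# column, re-sorts both tables by it, then falls back to the ordered compare.
target = {"性别": ["Male", "Female"], "平均年龄": ["27.84", "27.73"]}
print(DataCompareService().compare(standard, target).compare_result)  # DataCompareResultEnum.RIGHT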
@@ -0,0 +1,118 @@
import json
import os
from typing import List

import pandas as pd

from models import AnswerExecuteModel, BaseInputModel, DataCompareResultEnum, DataCompareStrategyConfig, RoundAnswerConfirmModel


class FileParseService:
    def parse_input_sets(self, path: str) -> List[BaseInputModel]:
        data = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                obj = json.loads(line)
                data.append(BaseInputModel(
                    serialNo=obj["serialNo"],
                    analysisModelId=obj["analysisModelId"],
                    question=obj["question"],
                    selfDefineTags=obj.get("selfDefineTags"),
                    prompt=obj.get("prompt"),
                    knowledge=obj.get("knowledge"),
                ))
        return data

    def parse_llm_outputs(self, path: str) -> List[AnswerExecuteModel]:
        data = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                obj = json.loads(line)
                data.append(AnswerExecuteModel.from_dict(obj))
        return data

    def write_data_compare_result(self, path: str, round_id: int, confirm_models: List[RoundAnswerConfirmModel], is_execute: bool, llm_count: int):
        if not path.endswith(".jsonl"):
            raise ValueError(f"output_file_path must end with .jsonl, got {path}")
        # Derive a sibling file: foo.jsonl -> foo.round<N>.compare.jsonl
        out_path = path.replace(".jsonl", f".round{round_id}.compare.jsonl")
        with open(out_path, "w", encoding="utf-8") as f:
            for cm in confirm_models:
                row = dict(
                    serialNo=cm.serialNo,
                    analysisModelId=cm.analysisModelId,
                    question=cm.question,
                    selfDefineTags=cm.selfDefineTags,
                    prompt=cm.prompt,
                    standardAnswerSql=cm.standardAnswerSql,
                    llmOutput=cm.llmOutput,
                    executeResult=cm.executeResult,
                    errorMsg=cm.errorMsg,
                    compareResult=cm.compareResult.value if cm.compareResult else None,
                    isExecute=is_execute,
                    llmCount=llm_count,
                )
                f.write(json.dumps(row, ensure_ascii=False) + "\n")
        print(f"[write_data_compare_result] compare written to: {out_path}")

    def summary_and_write_multi_round_benchmark_result(self, output_path: str, round_id: int) -> str:
        if not output_path.endswith(".jsonl"):
            raise ValueError(f"output_file_path must end with .jsonl, got {output_path}")
        compare_path = output_path.replace(".jsonl", f".round{round_id}.compare.jsonl")
        right, wrong, failed, exception = 0, 0, 0, 0
        if os.path.exists(compare_path):
            with open(compare_path, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue
                    obj = json.loads(line)
                    cr = obj.get("compareResult")
                    if cr == DataCompareResultEnum.RIGHT.value:
                        right += 1
                    elif cr == DataCompareResultEnum.WRONG.value:
                        wrong += 1
                    elif cr == DataCompareResultEnum.FAILED.value:
                        failed += 1
                    elif cr == DataCompareResultEnum.EXCEPTION.value:
                        exception += 1
        else:
            print(f"[summary] compare file not found: {compare_path}")
        summary_path = output_path.replace(".jsonl", f".round{round_id}.summary.json")
        result = dict(right=right, wrong=wrong, failed=failed, exception=exception)
        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
        print(f"[summary] summary written to: {summary_path} -> {result}")
        return json.dumps(result, ensure_ascii=False)

    def parse_standard_benchmark_sets(self, standard_excel_path: str) -> List[AnswerExecuteModel]:
        df = pd.read_excel(standard_excel_path, sheet_name="Sheet1")
        outputs: List[AnswerExecuteModel] = []
        for _, row in df.iterrows():
            try:
                serial_no = int(row["编号"])  # "编号" = serial number
            except Exception:
                continue
            question = row.get("用户问题")           # "用户问题" = user question
            analysis_model_id = row.get("数据集ID")  # "数据集ID" = dataset id
            # "标准答案SQL" = standard-answer SQL
            llm_output = None if pd.isna(row.get("标准答案SQL")) else str(row.get("标准答案SQL"))
            order_by = True
            if not pd.isna(row.get("是否排序")):     # "是否排序" = ordered flag (0/1)
                try:
                    order_by = bool(int(row.get("是否排序")))
                except Exception:
                    order_by = True

            std_result = None
            if not pd.isna(row.get("标准结果")):     # "标准结果" = standard result (JSON table)
                try:
                    std_result = json.loads(row.get("标准结果"))
                except Exception:
                    std_result = None

            strategy_config = DataCompareStrategyConfig(
                strategy="CONTAIN_MATCH",
                order_by=order_by,
                standard_result=[std_result] if std_result is not None else None  # wrap as a list of candidates
            )
            outputs.append(AnswerExecuteModel(
                serialNo=serial_no,
                analysisModelId=analysis_model_id,
                question=question,
                llmOutput=llm_output,
                executeResult=std_result,
                strategyConfig=strategy_config
            ))
        return outputs
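For reference, a sketch of producing a workbook that parse_standard_benchmark_sets accepts, not part of the commit. The sheet name and column headers are the ones the method reads; the path and row values are illustrative only:

import pandas as pd

rows = [{
    "编号": 1,  # serial number
    "数据集ID": "D2025050900161503000025249569",
    "用户问题": "各性别的平均年龄是多少,并按年龄顺序显示结果?",
    "标准答案SQL": "select gender, avg(age) from t group by gender",
    "是否排序": 1,  # parsed into order_by=True
    "标准结果": '{"性别": ["Female", "Male"], "平均年龄": ["27.73", "27.84"]}',
}]
pd.DataFrame(rows).to_excel("data/standard_answers.xlsx", sheet_name="Sheet1", index=False)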
@@ -0,0 +1,116 @@
# app/services/models.py
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional


class BenchmarkModeTypeEnum(str, Enum):
    BUILD = "BUILD"
    EXECUTE = "EXECUTE"


@dataclass
class DataCompareStrategyConfig:
    strategy: str  # "EXACT_MATCH" | "CONTAIN_MATCH"
    order_by: bool = True
    standard_result: Optional[List[Dict[str, List[str]]]] = None  # a list of candidate result tables


class DataCompareResultEnum(str, Enum):
    RIGHT = "RIGHT"
    WRONG = "WRONG"
    FAILED = "FAILED"
    EXCEPTION = "EXCEPTION"


@dataclass
class DataCompareResult:
    compare_result: DataCompareResultEnum
    msg: str = ""

    @staticmethod
    def right(msg=""):
        return DataCompareResult(DataCompareResultEnum.RIGHT, msg)

    @staticmethod
    def wrong(msg=""):
        return DataCompareResult(DataCompareResultEnum.WRONG, msg)

    @staticmethod
    def failed(msg=""):
        return DataCompareResult(DataCompareResultEnum.FAILED, msg)

    @staticmethod
    def exception(msg=""):
        return DataCompareResult(DataCompareResultEnum.EXCEPTION, msg)


@dataclass
class BaseInputModel:
    serialNo: int
    analysisModelId: str
    question: str
    selfDefineTags: Optional[str] = None
    prompt: Optional[str] = None
    knowledge: Optional[str] = None


@dataclass
class AnswerExecuteModel:
    serialNo: int
    analysisModelId: str
    question: str
    llmOutput: Optional[str]
    executeResult: Optional[Dict[str, List[str]]]
    errorMsg: Optional[str] = None
    strategyConfig: Optional[DataCompareStrategyConfig] = None
    cotTokens: Optional[Any] = None

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "AnswerExecuteModel":
        cfg = d.get("strategyConfig")
        strategy_config = None
        if cfg:
            std_list = cfg.get("standard_result")
            strategy_config = DataCompareStrategyConfig(
                strategy=cfg.get("strategy"),
                order_by=cfg.get("order_by", True),
                standard_result=std_list if isinstance(std_list, list) else None
            )
        return AnswerExecuteModel(
            serialNo=d["serialNo"],
            analysisModelId=d["analysisModelId"],
            question=d["question"],
            llmOutput=d.get("llmOutput"),
            executeResult=d.get("executeResult"),
            errorMsg=d.get("errorMsg"),
            strategyConfig=strategy_config,
            cotTokens=d.get("cotTokens"),
        )

    def to_dict(self) -> Dict[str, Any]:
        cfg = None
        if self.strategyConfig:
            cfg = dict(
                strategy=self.strategyConfig.strategy,
                order_by=self.strategyConfig.order_by,
                standard_result=self.strategyConfig.standard_result
            )
        return dict(
            serialNo=self.serialNo,
            analysisModelId=self.analysisModelId,
            question=self.question,
            llmOutput=self.llmOutput,
            executeResult=self.executeResult,
            errorMsg=self.errorMsg,
            strategyConfig=cfg,
            cotTokens=self.cotTokens
        )


@dataclass
class RoundAnswerConfirmModel:
    serialNo: int
    analysisModelId: str
    question: str
    selfDefineTags: Optional[str]
    prompt: Optional[str]
    standardAnswerSql: Optional[str] = None
    strategyConfig: Optional[DataCompareStrategyConfig] = None
    llmOutput: Optional[str] = None
    executeResult: Optional[Dict[str, List[str]]] = None
    errorMsg: Optional[str] = None
    compareResult: Optional[DataCompareResultEnum] = None


@dataclass
class BenchmarkExecuteConfig:
    benchmarkModeType: BenchmarkModeTypeEnum
    compareResultEnable: bool
    standardFilePath: Optional[str] = None
    compareConfig: Optional[Dict[str, str]] = None
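A quick round-trip check of the serialization pair, not part of the commit, using a record shaped like the serial-3 fixture above:

from models import AnswerExecuteModel

raw = {
    "serialNo": 3,
    "analysisModelId": "D2025050900161503000025249569",
    "question": "用于触发双模型结果数不相等的case",
    "llmOutput": "select 1",
    "executeResult": {"colA": ["x", "y"]},
    "errorMsg": None,
}
model = AnswerExecuteModel.from_dict(raw)
assert model.to_dict() == {**raw, "strategyConfig": None, "cotTokens": None}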
@@ -0,0 +1,65 @@
from file_parse_service import FileParseService
from data_compare_service import DataCompareService
from user_input_execute_service import UserInputExecuteService
from models import BenchmarkExecuteConfig, BenchmarkModeTypeEnum


def run_build_mode():
    fps = FileParseService()
    dcs = DataCompareService()
    svc = UserInputExecuteService(fps, dcs)

    inputs = fps.parse_input_sets("data/input_round1.jsonl")
    left = fps.parse_llm_outputs("data/output_round1_modelA.jsonl")
    right = fps.parse_llm_outputs("data/output_round1_modelB.jsonl")

    config = BenchmarkExecuteConfig(
        benchmarkModeType=BenchmarkModeTypeEnum.BUILD,
        compareResultEnable=True,
        standardFilePath=None,
        compareConfig={"check": "FULL_TEXT"}
    )

    svc.post_dispatch(
        round_id=1,
        config=config,
        inputs=inputs,
        left_outputs=left,
        right_outputs=right,
        input_file_path="data/input_round1.jsonl",
        output_file_path="data/output_round1_modelB.jsonl"
    )

    fps.summary_and_write_multi_round_benchmark_result("data/output_round1_modelB.jsonl", 1)
    print("BUILD compare path:", "data/output_round1_modelB.round1.compare.jsonl")


def run_execute_mode():
    fps = FileParseService()
    dcs = DataCompareService()
    svc = UserInputExecuteService(fps, dcs)

    inputs = fps.parse_input_sets("data/input_round1.jsonl")
    right = fps.parse_llm_outputs("data/output_execute_model.jsonl")

    config = BenchmarkExecuteConfig(
        benchmarkModeType=BenchmarkModeTypeEnum.EXECUTE,
        compareResultEnable=True,
        standardFilePath="data/standard_answers.xlsx",
        compareConfig=None
    )

    svc.post_dispatch(
        round_id=1,
        config=config,
        inputs=inputs,
        left_outputs=[],
        right_outputs=right,
        input_file_path="data/input_round1.jsonl",
        output_file_path="data/output_execute_model.jsonl"
    )

    fps.summary_and_write_multi_round_benchmark_result("data/output_execute_model.jsonl", 1)
    print("EXECUTE compare path:", "data/output_execute_model.round1.compare.jsonl")


if __name__ == "__main__":
    run_build_mode()
    run_execute_mode()
@@ -0,0 +1,108 @@
# app/services/user_input_execute_service.py
from typing import List

from models import (
    BaseInputModel, AnswerExecuteModel, RoundAnswerConfirmModel,
    BenchmarkExecuteConfig, BenchmarkModeTypeEnum, DataCompareResultEnum, DataCompareStrategyConfig
)
from file_parse_service import FileParseService
from data_compare_service import DataCompareService


class UserInputExecuteService:
    def __init__(self, file_service: FileParseService, compare_service: DataCompareService):
        self.file_service = file_service
        self.compare_service = compare_service

    def post_dispatch(
        self,
        round_id: int,
        config: BenchmarkExecuteConfig,
        inputs: List[BaseInputModel],
        left_outputs: List[AnswerExecuteModel],
        right_outputs: List[AnswerExecuteModel],
        input_file_path: str,
        output_file_path: str
    ):
        try:
            if config.benchmarkModeType == BenchmarkModeTypeEnum.BUILD and config.compareResultEnable:
                # BUILD mode: compare two models' outputs against each other.
                if left_outputs and right_outputs:
                    self._execute_llm_compare_result(output_file_path, round_id, inputs, left_outputs, right_outputs, config)
            elif config.benchmarkModeType == BenchmarkModeTypeEnum.EXECUTE and config.compareResultEnable:
                # EXECUTE mode: compare one model's outputs against the standard-answer workbook.
                if config.standardFilePath and right_outputs:
                    standard_sets = self.file_service.parse_standard_benchmark_sets(config.standardFilePath)
                    self._execute_llm_compare_result(output_file_path, round_id, inputs, standard_sets, right_outputs, config)
        except Exception as e:
            print(f"[post_dispatch] compare error: {e}")

    def _execute_llm_compare_result(
        self,
        location: str,
        round_id: int,
        inputs: List[BaseInputModel],
        left_answers: List[AnswerExecuteModel],
        right_answers: List[AnswerExecuteModel],
        config: BenchmarkExecuteConfig
    ):
        left_map = {a.serialNo: a for a in left_answers}
        right_map = {a.serialNo: a for a in right_answers}
        confirm_list: List[RoundAnswerConfirmModel] = []

        for inp in inputs:
            left = left_map.get(inp.serialNo)
            right = right_map.get(inp.serialNo)

            if left is None and right is None:
                continue

            strategy_cfg = None
            standard_sql = None
            if left is not None:
                standard_sql = left.llmOutput
                if config.benchmarkModeType == BenchmarkModeTypeEnum.EXECUTE:
                    # Standard sets parsed from the workbook already carry their strategy config.
                    strategy_cfg = left.strategyConfig
                else:
                    # BUILD mode: the left model's execute result becomes the standard.
                    standard_result_list = []
                    if left.executeResult:
                        standard_result_list.append(left.executeResult)
                    strategy_cfg = DataCompareStrategyConfig(
                        strategy="EXACT_MATCH",
                        order_by=True,
                        standard_result=standard_result_list if standard_result_list else None
                    )

            compare_result = None  # stays None when there is no right-side answer to compare
            if right is not None:
                if config.compareConfig and isinstance(config.compareConfig, dict):
                    res = self.compare_service.compare_json_by_config(
                        left.llmOutput if left else "", right.llmOutput or "", config.compareConfig
                    )
                    compare_result = res.compare_result
                else:
                    if strategy_cfg is None:
                        compare_result = DataCompareResultEnum.FAILED
                    else:
                        standard_model = left if left else AnswerExecuteModel(
                            serialNo=inp.serialNo,
                            analysisModelId=inp.analysisModelId,
                            question=inp.question,
                            llmOutput=None,
                            executeResult=None
                        )
                        # Attach the strategy config assembled above so compare() can use it.
                        standard_model.strategyConfig = strategy_cfg
                        res = self.compare_service.compare(standard_model, right.executeResult)
                        compare_result = res.compare_result

            confirm = RoundAnswerConfirmModel(
                serialNo=inp.serialNo,
                analysisModelId=inp.analysisModelId,
                question=inp.question,
                selfDefineTags=inp.selfDefineTags,
                prompt=inp.prompt,
                standardAnswerSql=standard_sql,
                strategyConfig=strategy_cfg,
                llmOutput=right.llmOutput if right else None,
                executeResult=right.executeResult if right else None,
                errorMsg=right.errorMsg if right else None,
                compareResult=compare_result
            )
            confirm_list.append(confirm)

        self.file_service.write_data_compare_result(location, round_id, confirm_list, config.benchmarkModeType == BenchmarkModeTypeEnum.EXECUTE, 2)
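A minimal sketch of the JSON strategy that post_dispatch routes to whenever compareConfig is set, not part of the commit; the values mirror the serial-4 fixture:

from data_compare_service import DataCompareService

res = DataCompareService().compare_json_by_config(
    standard_answer='{"check":"ok"}',
    answer='{"check":"ok"}',
    compare_config={"check": "FULL_TEXT"},
)
print(res.compare_result)  # DataCompareResultEnum.RIGHT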
@@ -26,8 +26,8 @@ class BenchmarkDataConfig(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)

     cache_dir: str = "cache"
-    db_path: str = "benchmark_data.db"
-    table_mapping_file: Optional[str] = None
+    db_path: str = "pilot/benchmark_meta_data/benchmark_data.db"
+    table_mapping_file: str = "pilot/benchmark_meta_data/table_mapping.json"
     cache_expiry_days: int = 1