mirror of https://github.com/csunny/DB-GPT.git
feat: benchmark post_dispatch service
@@ -0,0 +1,4 @@
{"serialNo":1,"analysisModelId":"D2025050900161503000025249569","question":"各性别的平均年龄是多少,并按年龄顺序显示结果?","selfDefineTags":"KAGGLE_DS_1,CTE1","prompt":"...","knowledge":""}
{"serialNo":2,"analysisModelId":"D2025050900161503000025249569","question":"不同投资目标下政府债券的总量是多少,并按目标名称排序?","selfDefineTags":"KAGGLE_DS_1,CTE1","prompt":"...","knowledge":""}
{"serialNo":3,"analysisModelId":"D2025050900161503000025249569","question":"用于触发双模型结果数不相等的case","selfDefineTags":"TEST","prompt":"...","knowledge":""}
{"serialNo":4,"analysisModelId":"D2025050900161503000025249569","question":"用于JSON对比策略的case","selfDefineTags":"TEST_JSON","prompt":"...","knowledge":""}
@@ -0,0 +1,5 @@
{"serialNo":1,"analysisModelId":"D2025050900161503000025249569","question":"各性别的平均年龄是多少,并按年龄顺序显示结果?","llmOutput":"with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;","executeResult":{"性别":["Female","Male"],"平均年龄":["27.73","27.84"]},"errorMsg":null}
{"serialNo":2,"analysisModelId":"D2025050900161503000025249569","question":"不同投资目标下政府债券的总量是多少,并按目标名称排序?","llmOutput":"with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;","executeResult":{"objective":["Capital Appreciation","Growth","Income"],"政府债券总量":["117","54","15"]},"errorMsg":null}
{"serialNo":3,"analysisModelId":"D2025050900161503000025249569","question":"用于触发双模型结果数不相等的case","llmOutput":"select 1","executeResult":{"colA":["x","y"]},"errorMsg":null}
{"serialNo":4,"analysisModelId":"D2025050900161503000025249569","question":"用于JSON对比策略的case","llmOutput":"{\"check\":\"ok\"}","executeResult":null,"errorMsg":null}
{"serialNo":5,"analysisModelId":"D2025050900161503000025249569","question":"缺少匹配标准的case","llmOutput":"select * from t","executeResult":null,"errorMsg":"execution error"}
@@ -0,0 +1,4 @@
{"serialNo": 1, "analysisModelId": "D2025050900161503000025249569", "question": "各性别的平均年龄是多少,并按年龄顺序显示结果?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with converted_data as (\n select \n gender,\n cast(age as int) as age\n from \n ant_icube_dev.di_finance_data\n where \n age rlike '^[0-9]+$'\n)\nselect\n gender as `性别`,\n avg(age) as `平均年龄`\nfrom \n converted_data\ngroup by \n gender\norder by \n `平均年龄`;", "llmOutput": "with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;", "executeResult": {"性别": ["Female", "Male"], "平均年龄": ["27.73", "27.84"]}, "errorMsg": null, "compareResult": "RIGHT", "isExecute": true, "llmCount": 2}
{"serialNo": 2, "analysisModelId": "D2025050900161503000025249569", "question": "不同投资目标下政府债券的总量是多少,并按目标名称排序?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with gov_bonds_data as (\n select\n objective,\n cast(government_bonds as bigint) as gov_bond_value\n from\n ant_icube_dev.di_finance_data\n where\n government_bonds is not null\n and government_bonds rlike '^[0-9]+$'\n)\nselect\n objective as `objective`,\n sum(gov_bond_value) as `政府债券总量`\nfrom\n gov_bonds_data\ngroup by\n `objective`\norder by\n `objective`;", "llmOutput": "with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;", "executeResult": {"objective": ["Capital Appreciation", "Growth", "Income"], "政府债券总量": ["117", "54", "15"]}, "errorMsg": null, "compareResult": "RIGHT", "isExecute": true, "llmCount": 2}
{"serialNo": 3, "analysisModelId": "D2025050900161503000025249569", "question": "用于触发双模型结果数不相等的case", "selfDefineTags": "TEST", "prompt": "...", "standardAnswerSql": null, "llmOutput": "select 1", "executeResult": {"colA": ["x", "y"]}, "errorMsg": null, "compareResult": "FAILED", "isExecute": true, "llmCount": 2}
{"serialNo": 4, "analysisModelId": "D2025050900161503000025249569", "question": "用于JSON对比策略的case", "selfDefineTags": "TEST_JSON", "prompt": "...", "standardAnswerSql": null, "llmOutput": "{\"check\":\"ok\"}", "executeResult": null, "errorMsg": null, "compareResult": "FAILED", "isExecute": true, "llmCount": 2}
@@ -0,0 +1,6 @@
{
  "right": 2,
  "wrong": 0,
  "failed": 2,
  "exception": 0
}
@@ -0,0 +1,4 @@
{"serialNo":1,"analysisModelId":"D2025050900161503000025249569","question":"各性别的平均年龄是多少,并按年龄顺序显示结果?","llmOutput":"with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;","executeResult":{"性别":["Female","Male"],"平均年龄":["27.73","27.84"]},"errorMsg":null}
{"serialNo":2,"analysisModelId":"D2025050900161503000025249569","question":"不同投资目标下政府债券的总量是多少,并按目标名称排序?","llmOutput":"with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;","executeResult":{"objective":["Capital Appreciation","Growth","Income"],"政府债券总量":["117","54","15"]},"errorMsg":null}
{"serialNo":3,"analysisModelId":"D2025050900161503000025249569","question":"用于触发双模型结果数不相等的case","llmOutput":"select 1","executeResult":{"colA":["x","y"]},"errorMsg":null}
{"serialNo":4,"analysisModelId":"D2025050900161503000025249569","question":"用于JSON对比策略的case","llmOutput":"{\"check\":\"ok\"}","executeResult":null,"errorMsg":null}
@@ -0,0 +1,4 @@
{"serialNo":1,"analysisModelId":"D2025050900161503000025249569","question":"各性别的平均年龄是多少,并按年龄顺序显示结果?","llmOutput":"with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;","executeResult":{"性别":["Female","Male"],"平均年龄":["27.73","27.84"]},"errorMsg":null}
{"serialNo":2,"analysisModelId":"D2025050900161503000025249569","question":"不同投资目标下政府债券的总量是多少,并按目标名称排序?","llmOutput":"with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;","executeResult":{"objective":["Capital Appreciation","Growth","Income"],"政府债券总量":["117","54","15"]},"errorMsg":null}
{"serialNo":3,"analysisModelId":"D2025050900161503000025249569","question":"用于触发双模型结果数不相等的case","llmOutput":"select 1","executeResult":{"colB":["x","z","w"]},"errorMsg":null}
{"serialNo":4,"analysisModelId":"D2025050900161503000025249569","question":"用于JSON对比策略的case","llmOutput":"{\"check\":\"ok\"}","executeResult":null,"errorMsg":null}
@@ -0,0 +1,4 @@
{"serialNo": 1, "analysisModelId": "D2025050900161503000025249569", "question": "各性别的平均年龄是多少,并按年龄顺序显示结果?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;", "llmOutput": "with converted_data as (...)\nselect gender as `性别`, avg(age) as `平均年龄` from converted_data group by gender order by `平均年龄`;", "executeResult": {"性别": ["Female", "Male"], "平均年龄": ["27.73", "27.84"]}, "errorMsg": null, "compareResult": "EXCEPTION", "isExecute": false, "llmCount": 2}
{"serialNo": 2, "analysisModelId": "D2025050900161503000025249569", "question": "不同投资目标下政府债券的总量是多少,并按目标名称排序?", "selfDefineTags": "KAGGLE_DS_1,CTE1", "prompt": "...", "standardAnswerSql": "with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;", "llmOutput": "with gov_bonds_data as (...)\nselect objective as `objective`, sum(gov_bond_value) as `政府债券总量` from gov_bonds_data group by `objective` order by `objective`;", "executeResult": {"objective": ["Capital Appreciation", "Growth", "Income"], "政府债券总量": ["117", "54", "15"]}, "errorMsg": null, "compareResult": "EXCEPTION", "isExecute": false, "llmCount": 2}
{"serialNo": 3, "analysisModelId": "D2025050900161503000025249569", "question": "用于触发双模型结果数不相等的case", "selfDefineTags": "TEST", "prompt": "...", "standardAnswerSql": "select 1", "llmOutput": "select 1", "executeResult": {"colB": ["x", "z", "w"]}, "errorMsg": null, "compareResult": "EXCEPTION", "isExecute": false, "llmCount": 2}
{"serialNo": 4, "analysisModelId": "D2025050900161503000025249569", "question": "用于JSON对比策略的case", "selfDefineTags": "TEST_JSON", "prompt": "...", "standardAnswerSql": "{\"check\":\"ok\"}", "llmOutput": "{\"check\":\"ok\"}", "executeResult": null, "errorMsg": null, "compareResult": "RIGHT", "isExecute": false, "llmCount": 2}
@@ -0,0 +1,6 @@
{
  "right": 1,
  "wrong": 0,
  "failed": 0,
  "exception": 3
}
Binary file not shown.
@@ -0,0 +1,144 @@
from typing import Dict, List, Optional
from models import DataCompareResult, DataCompareResultEnum, DataCompareStrategyConfig, AnswerExecuteModel
from copy import deepcopy
import hashlib
import json
from decimal import Decimal, ROUND_HALF_UP


def md5_list(values: List[str]) -> str:
    # Join a column into one comma-separated string and fingerprint it; None becomes "".
    s = ",".join([v if v is not None else "" for v in values])
    return hashlib.md5(s.encode("utf-8")).hexdigest()


def accurate_decimal(table: Dict[str, List[str]], scale: int = 2) -> Dict[str, List[str]]:
    # Normalize numeric-looking cells to a fixed scale (ROUND_HALF_UP);
    # non-numeric cells pass through unchanged.
    out = {}
    for k, col in table.items():
        new_col = []
        for v in col:
            if v is None:
                new_col.append("")
                continue
            vs = str(v)
            try:
                d = Decimal(vs)
                new_col.append(str(d.quantize(Decimal("1." + "0" * scale), rounding=ROUND_HALF_UP)))
            except Exception:
                new_col.append(vs)
        out[k] = new_col
    return out


def sort_columns_by_key(table: Dict[str, List[str]], sort_key: str) -> Dict[str, List[str]]:
    # Reorder every column by the row order induced by sort_key.
    if sort_key not in table:
        raise ValueError(f"base col not exist: {sort_key}")
    base = table[sort_key]
    row_count = len(base)
    for k, col in table.items():
        if len(col) != row_count:
            raise ValueError(f"col length diff: {k}")
    indices = list(range(row_count))
    indices.sort(key=lambda i: "" if base[i] is None else str(base[i]))
    sorted_table = {}
    for k in table.keys():
        sorted_table[k] = [table[k][i] for i in indices]
    return sorted_table


class DataCompareService:
    def compare(self, standard_model: AnswerExecuteModel, target_result: Optional[Dict[str, List[str]]]) -> DataCompareResult:
        if target_result is None:
            return DataCompareResult.failed("targetResult is null")
        cfg: DataCompareStrategyConfig = standard_model.strategyConfig or DataCompareStrategyConfig(
            strategy="EXACT_MATCH", order_by=True, standard_result=None
        )
        if not cfg.standard_result:
            return DataCompareResult.failed("leftResult is null")

        # The target is RIGHT if it matches any of the standard candidate tables.
        for std in cfg.standard_result:
            if not isinstance(std, dict):
                continue
            std_fmt = accurate_decimal(deepcopy(std), 2)
            tgt_fmt = accurate_decimal(deepcopy(target_result), 2)
            if cfg.order_by:
                res = self._compare_ordered(std_fmt, cfg, tgt_fmt)
            else:
                res = self._compare_unordered(std_fmt, cfg, tgt_fmt)
            if res.compare_result == DataCompareResultEnum.RIGHT:
                return res
        return DataCompareResult.wrong("compareResult wrong!")

    def _compare_ordered(self, std: Dict[str, List[str]], cfg: DataCompareStrategyConfig, tgt: Dict[str, List[str]]) -> DataCompareResult:
        try:
            # Fingerprint each column in row order; column names are ignored.
            std_md5 = set()
            for col_vals in std.values():
                lst = ["" if v is None else str(v) for v in col_vals]
                std_md5.add(md5_list(lst))

            tgt_md5 = set()
            for col_vals in tgt.values():
                lst = ["" if v is None else str(v) for v in col_vals]
                tgt_md5.add(md5_list(lst))

            tgt_size = len(tgt_md5)
            inter = tgt_md5.intersection(std_md5)

            if tgt_size == len(inter) and tgt_size == len(std_md5):
                return DataCompareResult.right("compareResult success!")

            # Every standard column is present in the target, but the target has extras.
            if len(std_md5) == len(inter):
                if cfg.strategy == "EXACT_MATCH":
                    return DataCompareResult.failed("compareResult failed!")
                elif cfg.strategy == "CONTAIN_MATCH":
                    return DataCompareResult.right("compareResult success!")
            return DataCompareResult.wrong("compareResult wrong!")
        except Exception as e:
            return DataCompareResult.exception(f"compareResult Exception! {e}")

    def _compare_unordered(self, std: Dict[str, List[str]], cfg: DataCompareStrategyConfig, tgt: Dict[str, List[str]]) -> DataCompareResult:
        try:
            # Fingerprint target columns with values sorted, so row order is ignored.
            tgt_md5 = []
            tgt_cols = []
            for k, col_vals in tgt.items():
                lst = ["" if v is None else str(v) for v in col_vals]
                lst.sort()
                tgt_md5.append(md5_list(lst))
                tgt_cols.append(k)

            for std_key, std_vals in std.items():
                std_list = ["" if v is None else str(v) for v in std_vals]
                std_list.sort()
                std_md5 = md5_list(std_list)
                if std_md5 not in tgt_md5:
                    return DataCompareResult.wrong("compareResult wrong!")

                idx = tgt_md5.index(std_md5)
                tgt_key = tgt_cols[idx]

                # Align both tables on the matched column, then compare as ordered tables.
                std_sorted = sort_columns_by_key(std, std_key)
                tgt_sorted = sort_columns_by_key(tgt, tgt_key)

                ordered_cfg = DataCompareStrategyConfig(
                    strategy=cfg.strategy,
                    order_by=True,
                    standard_result=cfg.standard_result
                )
                res = self._compare_ordered(std_sorted, ordered_cfg, tgt_sorted)
                if res.compare_result == DataCompareResultEnum.RIGHT:
                    return res
            return DataCompareResult.wrong("compareResult wrong!")
        except Exception as e:
            return DataCompareResult.exception(f"compareResult Exception! {e}")

    def compare_json_by_config(self, standard_answer: str, answer: str, compare_config: Dict[str, str]) -> DataCompareResult:
        try:
            if not standard_answer or not answer:
                return DataCompareResult.failed("standardAnswer or answer is null")
            std = json.loads(standard_answer)
            ans = json.loads(answer)
            for k, strat in compare_config.items():
                if k not in ans:
                    return DataCompareResult.wrong("key missing")
                if strat in ("FULL_TEXT", "ARRAY"):
                    # Compare the configured key against the standard answer's value.
                    if str(ans[k]) != str(std.get(k)):
                        return DataCompareResult.wrong("value mismatch")
                elif strat == "DAL":
                    return DataCompareResult.failed("DAL compare not supported in mock")
                else:
                    return DataCompareResult.failed(f"unknown strategy {strat}")
            return DataCompareResult.right("json compare success")
        except Exception as e:
            return DataCompareResult.exception(f"compareJsonByConfig Exception! {e}")
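For orientation: both comparison paths reduce a result table to per-column MD5 fingerprints, so column names never matter, only the column contents. A minimal standalone sketch of that idea (hypothetical tables, not part of the commit):

    import hashlib

    def fingerprint(col):
        # same serialization as md5_list above: comma-joined cell strings
        return hashlib.md5(",".join(col).encode("utf-8")).hexdigest()

    std = {"性别": ["Female", "Male"], "平均年龄": ["27.73", "27.84"]}
    tgt = {"gender": ["Female", "Male"], "avg_age": ["27.73", "27.84"]}

    # Under EXACT_MATCH, two tables match when their fingerprint sets coincide,
    # regardless of how the columns are named or ordered in the dict.
    assert {fingerprint(c) for c in std.values()} == {fingerprint(c) for c in tgt.values()}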
@@ -0,0 +1,118 @@
import json
from typing import List
from models import BaseInputModel, AnswerExecuteModel, RoundAnswerConfirmModel, DataCompareResultEnum, DataCompareStrategyConfig
import pandas as pd
import os


class FileParseService:
    def parse_input_sets(self, path: str) -> List[BaseInputModel]:
        data = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                obj = json.loads(line)
                data.append(BaseInputModel(
                    serialNo=obj["serialNo"],
                    analysisModelId=obj["analysisModelId"],
                    question=obj["question"],
                    selfDefineTags=obj.get("selfDefineTags"),
                    prompt=obj.get("prompt"),
                    knowledge=obj.get("knowledge"),
                ))
        return data

    def parse_llm_outputs(self, path: str) -> List[AnswerExecuteModel]:
        data = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                obj = json.loads(line)
                data.append(AnswerExecuteModel.from_dict(obj))
        return data

    def write_data_compare_result(self, path: str, round_id: int, confirm_models: List[RoundAnswerConfirmModel], is_execute: bool, llm_count: int):
        if not path.endswith(".jsonl"):
            raise ValueError(f"output_file_path must end with .jsonl, got {path}")
        out_path = path.replace(".jsonl", f".round{round_id}.compare.jsonl")
        with open(out_path, "w", encoding="utf-8") as f:
            for cm in confirm_models:
                row = dict(
                    serialNo=cm.serialNo,
                    analysisModelId=cm.analysisModelId,
                    question=cm.question,
                    selfDefineTags=cm.selfDefineTags,
                    prompt=cm.prompt,
                    standardAnswerSql=cm.standardAnswerSql,
                    llmOutput=cm.llmOutput,
                    executeResult=cm.executeResult,
                    errorMsg=cm.errorMsg,
                    compareResult=cm.compareResult.value if cm.compareResult else None,
                    isExecute=is_execute,
                    llmCount=llm_count
                )
                f.write(json.dumps(row, ensure_ascii=False) + "\n")
        print(f"[write_data_compare_result] compare written to: {out_path}")

    def summary_and_write_multi_round_benchmark_result(self, output_path: str, round_id: int) -> str:
        if not output_path.endswith(".jsonl"):
            raise ValueError(f"output_file_path must end with .jsonl, got {output_path}")
        compare_path = output_path.replace(".jsonl", f".round{round_id}.compare.jsonl")
        right, wrong, failed, exception = 0, 0, 0, 0
        if os.path.exists(compare_path):
            with open(compare_path, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue
                    obj = json.loads(line)
                    cr = obj.get("compareResult")
                    if cr == DataCompareResultEnum.RIGHT.value:
                        right += 1
                    elif cr == DataCompareResultEnum.WRONG.value:
                        wrong += 1
                    elif cr == DataCompareResultEnum.FAILED.value:
                        failed += 1
                    elif cr == DataCompareResultEnum.EXCEPTION.value:
                        exception += 1
        else:
            print(f"[summary] compare file not found: {compare_path}")
        summary_path = output_path.replace(".jsonl", f".round{round_id}.summary.json")
        result = dict(right=right, wrong=wrong, failed=failed, exception=exception)
        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
        print(f"[summary] summary written to: {summary_path} -> {result}")
        return json.dumps(result, ensure_ascii=False)

    def parse_standard_benchmark_sets(self, standard_excel_path: str) -> List[AnswerExecuteModel]:
        df = pd.read_excel(standard_excel_path, sheet_name="Sheet1")
        outputs: List[AnswerExecuteModel] = []
        for _, row in df.iterrows():
            try:
                serial_no = int(row["编号"])
            except Exception:
                continue
            question = row.get("用户问题")
            analysis_model_id = row.get("数据集ID")
            llm_output = None if pd.isna(row.get("标准答案SQL")) else str(row.get("标准答案SQL"))
            order_by = True
            if not pd.isna(row.get("是否排序")):
                try:
                    order_by = bool(int(row.get("是否排序")))
                except Exception:
                    order_by = True

            std_result = None
            if not pd.isna(row.get("标准结果")):
                try:
                    std_result = json.loads(row.get("标准结果"))
                except Exception:
                    std_result = None

            strategy_config = DataCompareStrategyConfig(
                strategy="CONTAIN_MATCH",
                order_by=order_by,
                standard_result=[std_result] if std_result is not None else None  # wrap in a list
            )
            outputs.append(AnswerExecuteModel(
                serialNo=serial_no,
                analysisModelId=analysis_model_id,
                question=question,
                llmOutput=llm_output,
                executeResult=std_result,
                strategyConfig=strategy_config
            ))
        return outputs
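The writer and the summarizer agree on file naming purely by string convention: every derived artifact replaces the `.jsonl` suffix of the round's output file. A quick sketch of the paths produced for the fixture files above:

    output_path = "data/output_round1_modelB.jsonl"
    round_id = 1
    compare_path = output_path.replace(".jsonl", f".round{round_id}.compare.jsonl")
    summary_path = output_path.replace(".jsonl", f".round{round_id}.summary.json")
    # -> data/output_round1_modelB.round1.compare.jsonl
    # -> data/output_round1_modelB.round1.summary.json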
@@ -0,0 +1,116 @@
# app/services/models.py
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional


class BenchmarkModeTypeEnum(str, Enum):
    BUILD = "BUILD"
    EXECUTE = "EXECUTE"


@dataclass
class DataCompareStrategyConfig:
    strategy: str  # "EXACT_MATCH" | "CONTAIN_MATCH"
    order_by: bool = True
    standard_result: Optional[List[Dict[str, List[str]]]] = None  # changed to list[dict]


class DataCompareResultEnum(str, Enum):
    RIGHT = "RIGHT"
    WRONG = "WRONG"
    FAILED = "FAILED"
    EXCEPTION = "EXCEPTION"


@dataclass
class DataCompareResult:
    compare_result: DataCompareResultEnum
    msg: str = ""

    @staticmethod
    def right(msg=""): return DataCompareResult(DataCompareResultEnum.RIGHT, msg)

    @staticmethod
    def wrong(msg=""): return DataCompareResult(DataCompareResultEnum.WRONG, msg)

    @staticmethod
    def failed(msg=""): return DataCompareResult(DataCompareResultEnum.FAILED, msg)

    @staticmethod
    def exception(msg=""): return DataCompareResult(DataCompareResultEnum.EXCEPTION, msg)


@dataclass
class BaseInputModel:
    serialNo: int
    analysisModelId: str
    question: str
    selfDefineTags: Optional[str] = None
    prompt: Optional[str] = None
    knowledge: Optional[str] = None


@dataclass
class AnswerExecuteModel:
    serialNo: int
    analysisModelId: str
    question: str
    llmOutput: Optional[str]
    executeResult: Optional[Dict[str, List[str]]]
    errorMsg: Optional[str] = None
    strategyConfig: Optional[DataCompareStrategyConfig] = None
    cotTokens: Optional[Any] = None

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "AnswerExecuteModel":
        cfg = d.get("strategyConfig")
        strategy_config = None
        if cfg:
            std_list = cfg.get("standard_result")
            strategy_config = DataCompareStrategyConfig(
                strategy=cfg.get("strategy"),
                order_by=cfg.get("order_by", True),
                standard_result=std_list if isinstance(std_list, list) else None
            )
        return AnswerExecuteModel(
            serialNo=d["serialNo"],
            analysisModelId=d["analysisModelId"],
            question=d["question"],
            llmOutput=d.get("llmOutput"),
            executeResult=d.get("executeResult"),
            errorMsg=d.get("errorMsg"),
            strategyConfig=strategy_config,
            cotTokens=d.get("cotTokens"),
        )

    def to_dict(self) -> Dict[str, Any]:
        cfg = None
        if self.strategyConfig:
            cfg = dict(
                strategy=self.strategyConfig.strategy,
                order_by=self.strategyConfig.order_by,
                standard_result=self.strategyConfig.standard_result
            )
        return dict(
            serialNo=self.serialNo,
            analysisModelId=self.analysisModelId,
            question=self.question,
            llmOutput=self.llmOutput,
            executeResult=self.executeResult,
            errorMsg=self.errorMsg,
            strategyConfig=cfg,
            cotTokens=self.cotTokens
        )


@dataclass
class RoundAnswerConfirmModel:
    serialNo: int
    analysisModelId: str
    question: str
    selfDefineTags: Optional[str]
    prompt: Optional[str]
    standardAnswerSql: Optional[str] = None
    strategyConfig: Optional[DataCompareStrategyConfig] = None
    llmOutput: Optional[str] = None
    executeResult: Optional[Dict[str, List[str]]] = None
    errorMsg: Optional[str] = None
    compareResult: Optional[DataCompareResultEnum] = None


@dataclass
class BenchmarkExecuteConfig:
    benchmarkModeType: BenchmarkModeTypeEnum
    compareResultEnable: bool
    standardFilePath: Optional[str] = None
    compareConfig: Optional[Dict[str, str]] = None
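`from_dict` and `to_dict` are the (de)serialization points for the JSONL fixtures, so a record should survive a round trip. A minimal check using a fixture-shaped record (hypothetical snippet, assuming models.py is importable):

    from models import AnswerExecuteModel

    record = {
        "serialNo": 4,
        "analysisModelId": "D2025050900161503000025249569",
        "question": "case for the JSON comparison strategy",
        "llmOutput": "{\"check\":\"ok\"}",
        "executeResult": None,
        "errorMsg": None,
    }
    model = AnswerExecuteModel.from_dict(record)
    assert model.to_dict()["llmOutput"] == record["llmOutput"]
    assert model.strategyConfig is None  # no strategyConfig key in the record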
@@ -0,0 +1,65 @@
from file_parse_service import FileParseService
from data_compare_service import DataCompareService
from user_input_execute_service import UserInputExecuteService
from models import BenchmarkExecuteConfig, BenchmarkModeTypeEnum


def run_build_mode():
    fps = FileParseService()
    dcs = DataCompareService()
    svc = UserInputExecuteService(fps, dcs)

    inputs = fps.parse_input_sets("data/input_round1.jsonl")
    left = fps.parse_llm_outputs("data/output_round1_modelA.jsonl")
    right = fps.parse_llm_outputs("data/output_round1_modelB.jsonl")

    config = BenchmarkExecuteConfig(
        benchmarkModeType=BenchmarkModeTypeEnum.BUILD,
        compareResultEnable=True,
        standardFilePath=None,
        compareConfig={"check": "FULL_TEXT"}
    )

    svc.post_dispatch(
        round_id=1,
        config=config,
        inputs=inputs,
        left_outputs=left,
        right_outputs=right,
        input_file_path="data/input_round1.jsonl",
        output_file_path="data/output_round1_modelB.jsonl"
    )

    fps.summary_and_write_multi_round_benchmark_result("data/output_round1_modelB.jsonl", 1)
    print("BUILD compare path:", "data/output_round1_modelB.round1.compare.jsonl")


def run_execute_mode():
    fps = FileParseService()
    dcs = DataCompareService()
    svc = UserInputExecuteService(fps, dcs)

    inputs = fps.parse_input_sets("data/input_round1.jsonl")
    right = fps.parse_llm_outputs("data/output_execute_model.jsonl")

    config = BenchmarkExecuteConfig(
        benchmarkModeType=BenchmarkModeTypeEnum.EXECUTE,
        compareResultEnable=True,
        standardFilePath="data/standard_answers.xlsx",
        compareConfig=None
    )

    svc.post_dispatch(
        round_id=1,
        config=config,
        inputs=inputs,
        left_outputs=[],
        right_outputs=right,
        input_file_path="data/input_round1.jsonl",
        output_file_path="data/output_execute_model.jsonl"
    )

    fps.summary_and_write_multi_round_benchmark_result("data/output_execute_model.jsonl", 1)
    print("EXECUTE compare path:", "data/output_execute_model.round1.compare.jsonl")


if __name__ == "__main__":
    run_build_mode()
    run_execute_mode()
@@ -0,0 +1,108 @@
# app/services/user_input_execute_service.py
from typing import List
from models import (
    BaseInputModel, AnswerExecuteModel, RoundAnswerConfirmModel,
    BenchmarkExecuteConfig, BenchmarkModeTypeEnum, DataCompareResultEnum, DataCompareStrategyConfig
)
from file_parse_service import FileParseService
from data_compare_service import DataCompareService


class UserInputExecuteService:
    def __init__(self, file_service: FileParseService, compare_service: DataCompareService):
        self.file_service = file_service
        self.compare_service = compare_service

    def post_dispatch(
        self,
        round_id: int,
        config: BenchmarkExecuteConfig,
        inputs: List[BaseInputModel],
        left_outputs: List[AnswerExecuteModel],
        right_outputs: List[AnswerExecuteModel],
        input_file_path: str,
        output_file_path: str
    ):
        try:
            if config.benchmarkModeType == BenchmarkModeTypeEnum.BUILD and config.compareResultEnable:
                # BUILD mode: compare two models' outputs against each other.
                if left_outputs and right_outputs:
                    self._execute_llm_compare_result(output_file_path, round_id, inputs, left_outputs, right_outputs, config)
            elif config.benchmarkModeType == BenchmarkModeTypeEnum.EXECUTE and config.compareResultEnable:
                # EXECUTE mode: compare one model's outputs against the standard answer set.
                if config.standardFilePath and right_outputs:
                    standard_sets = self.file_service.parse_standard_benchmark_sets(config.standardFilePath)
                    self._execute_llm_compare_result(output_file_path, round_id, inputs, standard_sets, right_outputs, config)
        except Exception as e:
            print(f"[post_dispatch] compare error: {e}")

    def _execute_llm_compare_result(
        self,
        location: str,
        round_id: int,
        inputs: List[BaseInputModel],
        left_answers: List[AnswerExecuteModel],
        right_answers: List[AnswerExecuteModel],
        config: BenchmarkExecuteConfig
    ):
        left_map = {a.serialNo: a for a in left_answers}
        right_map = {a.serialNo: a for a in right_answers}
        confirm_list: List[RoundAnswerConfirmModel] = []

        for inp in inputs:
            left = left_map.get(inp.serialNo)
            right = right_map.get(inp.serialNo)

            if left is None and right is None:
                continue

            strategy_cfg = None
            standard_sql = None
            if left is not None:
                standard_sql = left.llmOutput
                if config.benchmarkModeType == BenchmarkModeTypeEnum.EXECUTE:
                    strategy_cfg = left.strategyConfig
                else:
                    # BUILD mode: the left model's execution result is the standard candidate.
                    standard_result_list = []
                    if left.executeResult:
                        standard_result_list.append(left.executeResult)
                    strategy_cfg = DataCompareStrategyConfig(
                        strategy="EXACT_MATCH",
                        order_by=True,
                        standard_result=standard_result_list if standard_result_list else None
                    )

            compare_result = None  # stays None when there is no right-side answer to judge
            if right is not None:
                if config.compareConfig and isinstance(config.compareConfig, dict):
                    res = self.compare_service.compare_json_by_config(
                        left.llmOutput if left else "", right.llmOutput or "", config.compareConfig
                    )
                    compare_result = res.compare_result
                else:
                    if strategy_cfg is None:
                        compare_result = DataCompareResultEnum.FAILED
                    else:
                        res = self.compare_service.compare(
                            left if left else AnswerExecuteModel(
                                serialNo=inp.serialNo,
                                analysisModelId=inp.analysisModelId,
                                question=inp.question,
                                llmOutput=None,
                                executeResult=None
                            ),
                            right.executeResult
                        )
                        compare_result = res.compare_result
            confirm = RoundAnswerConfirmModel(
                serialNo=inp.serialNo,
                analysisModelId=inp.analysisModelId,
                question=inp.question,
                selfDefineTags=inp.selfDefineTags,
                prompt=inp.prompt,
                standardAnswerSql=standard_sql,
                strategyConfig=strategy_cfg,
                llmOutput=right.llmOutput if right else None,
                executeResult=right.executeResult if right else None,
                errorMsg=right.errorMsg if right else None,
                compareResult=compare_result
            )
            confirm_list.append(confirm)

        self.file_service.write_data_compare_result(location, round_id, confirm_list, config.benchmarkModeType == BenchmarkModeTypeEnum.EXECUTE, 2)
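To make the dispatch concrete: in BUILD mode with no compareConfig, the left model's executeResult becomes the single standard candidate and the right model's executeResult is judged against it. A hand-run of fixture case 3 (hypothetical snippet, assuming the services are importable):

    from models import AnswerExecuteModel, DataCompareStrategyConfig
    from data_compare_service import DataCompareService

    left = AnswerExecuteModel(
        serialNo=3, analysisModelId="D2025050900161503000025249569",
        question="case that triggers unequal result counts between the two models",
        llmOutput="select 1", executeResult={"colA": ["x", "y"]},
        strategyConfig=DataCompareStrategyConfig(
            strategy="EXACT_MATCH", order_by=True,
            standard_result=[{"colA": ["x", "y"]}],
        ),
    )
    res = DataCompareService().compare(left, {"colB": ["x", "z", "w"]})
    print(res.compare_result)  # WRONG: the column fingerprints do not intersect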
@@ -26,8 +26,8 @@ class BenchmarkDataConfig(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
     cache_dir: str = "cache"
-    db_path: str = "benchmark_data.db"
-    table_mapping_file: Optional[str] = None
+    db_path: str = "pilot/benchmark_meta_data/benchmark_data.db"
+    table_mapping_file: str = "pilot/benchmark_meta_data/table_mapping.json"
     cache_expiry_days: int = 1