feat(benchmark): multi-model post-processing

author alan.cl
date 2025-10-13 14:38:45 +08:00
parent 9b81a10866
commit 87f11b574d
3 changed files with 18 additions and 14 deletions


@@ -326,10 +326,10 @@ class BenchmarkService(
         )
         output_sets = BenchmarkDataSets[OutputType]()
-        output_list = []
+        output_list: List[OutputType] = []
-        written_batches = set()  # track batches that have already been written
-        complete_map = {}  # track task completion state, as Dict[int, OutputType]
+        written_batches: set[int] = set()
+        complete_map: Dict[int, OutputType] = {}
         # thread lock to ensure thread safety
         lock = threading.Lock()
@@ -356,7 +356,6 @@ class BenchmarkService(
                 f" output={json.dumps(output.to_dict(), ensure_ascii=False)}"
             )
-            # append the result in a thread-safe way
            with lock:
                output_list.append(output)
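Note: the two hunks above only tighten the typing and drop the inline comments; the pattern underneath is a set of shared accumulators guarded by a single lock while worker threads append per-model results. A minimal, self-contained sketch of that pattern, using placeholder names (OutputType, run_model, collect_outputs) rather than the service's real API:

```python
import threading
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class OutputType:
    """Placeholder for one model's benchmark output."""
    model_id: int
    answer: str


def run_model(model_id: int) -> OutputType:
    """Stand-in for the per-model benchmark call."""
    return OutputType(model_id=model_id, answer=f"answer from model {model_id}")


def collect_outputs(model_ids: List[int]) -> List[OutputType]:
    output_list: List[OutputType] = []
    written_batches: set[int] = set()
    complete_map: Dict[int, OutputType] = {}
    lock = threading.Lock()  # guards the three shared containers above

    def worker(model_id: int) -> None:
        output = run_model(model_id)
        with lock:  # append results in a thread-safe way
            output_list.append(output)
            complete_map[model_id] = output
            written_batches.add(model_id)

    with ThreadPoolExecutor(max_workers=4) as pool:
        for model_id in model_ids:
            pool.submit(worker, model_id)

    return output_list


if __name__ == "__main__":
    print(collect_outputs([0, 1, 2]))
```

Keeping all three containers behind the same lock avoids states where output_list and complete_map disagree about which tasks have finished.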
@@ -490,15 +489,16 @@ class BenchmarkService(
                            input_file_path: str, output_file_path: str):
         """
         Post dispatch processing standard result compare LLM execute result
-        and write compare result to file
+        and write compare result to file
         """
-        self.user_input_execute_service.post_dispatch(
-            i,
-            config,
-            input_list,
-            None,
-            output_list[0].benchmark_data_sets.data_list,
-            input_file_path,
-            output_file_path,
-        )
+        for j, output_result in enumerate(output_list):
+            self.user_input_execute_service.post_dispatch(
+                i,
+                config,
+                input_list,
+                None,
+                output_result.benchmark_data_sets.data_list,
+                input_file_path,
+                output_file_path,
+            )
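This hunk is the core of the commit: instead of post-dispatching only output_list[0], every model's result set is now compared and written. A hedged sketch of that fan-out; ModelOutput, the simplified non-generic BenchmarkDataSets, post_dispatch_all, and fake_post_dispatch are illustrative stand-ins that only mirror the call shape visible in the diff:

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List


@dataclass
class BenchmarkDataSets:
    data_list: List[dict] = field(default_factory=list)


@dataclass
class ModelOutput:
    """Placeholder for one model's benchmark output."""
    benchmark_data_sets: BenchmarkDataSets


def post_dispatch_all(
    post_dispatch: Callable[..., None],
    i: int,
    config: Any,
    input_list: List[Any],
    output_list: List[ModelOutput],
    input_file_path: str,
    output_file_path: str,
) -> None:
    # Before: only output_list[0] was compared and written.
    # After: every model's result set goes through the same step.
    for output_result in output_list:
        post_dispatch(
            i,
            config,
            input_list,
            None,
            output_result.benchmark_data_sets.data_list,
            input_file_path,
            output_file_path,
        )


if __name__ == "__main__":
    def fake_post_dispatch(i, config, input_list, _unused, data_list, in_path, out_path):
        print(f"batch={i} rows={len(data_list)} -> {out_path}")

    outputs = [
        ModelOutput(BenchmarkDataSets([{"q": "1+1", "a": "2"}])),
        ModelOutput(BenchmarkDataSets([{"q": "2+2", "a": "4"}, {"q": "3+3", "a": "6"}])),
    ]
    post_dispatch_all(fake_post_dispatch, 0, {}, [], outputs, "in.json", "out.json")
```

The diff keeps the enumerate index j even though it is unused; if per-model output files are wanted later, that index (or the new llm_code field) would be the natural hook for suffixing output_file_path.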


@@ -61,6 +61,7 @@ class AnswerExecuteModel:
     strategyConfig: Optional[DataCompareStrategyConfig] = None
     cotTokens: Optional[Any] = None
     cost_time: Optional[int] = None
+    llm_code: Optional[str] = None

     @staticmethod
     def from_dict(d: Dict[str, Any]) -> "AnswerExecuteModel":

@@ -83,6 +84,7 @@ class AnswerExecuteModel:
             strategyConfig=strategy_config,
             cotTokens=d.get("cotTokens"),
             cost_time=d.get("cost_time"),
+            llm_code=d.get("llm_code"),
         )

     def to_dict(self) -> Dict[str, Any]:

@@ -103,6 +105,7 @@ class AnswerExecuteModel:
             strategyConfig=cfg,
             cotTokens=self.cotTokens,
             cost_time=self.cost_time,
+            llm_code=self.llm_code,
         )


@@ -268,6 +268,7 @@ class UserInputExecuteService:
             executeResult=execute_result,
             cotTokens=response.cot_tokens,
             errorMsg=error_msg,
+            llm_code=input.llm_code,
         )

     def _extract_sql_content(self, content: str) -> str:
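The service-side change copies the model identifier from the incoming request onto the answer record, which is what lets the multi-model post-processing above tell results apart. A sketch of that propagation under assumed shapes; ExecuteInput, ExecuteResponse, AnswerRecord, and build_answer are illustrative, not the repo's classes:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExecuteInput:
    """Illustrative stand-in for the per-model execution request."""
    question: str
    llm_code: Optional[str] = None  # which model should answer


@dataclass
class ExecuteResponse:
    """Illustrative stand-in for the raw LLM response."""
    content: str
    cot_tokens: Optional[int] = None


@dataclass
class AnswerRecord:
    """Illustrative stand-in for AnswerExecuteModel."""
    executeResult: str
    cotTokens: Optional[int] = None
    errorMsg: Optional[str] = None
    llm_code: Optional[str] = None


def build_answer(input: ExecuteInput, response: ExecuteResponse,
                 error_msg: Optional[str] = None) -> AnswerRecord:
    # The spirit of the change: carry input.llm_code onto the record
    # so answers from different models stay distinguishable downstream.
    return AnswerRecord(
        executeResult=response.content,
        cotTokens=response.cot_tokens,
        errorMsg=error_msg,
        llm_code=input.llm_code,
    )


if __name__ == "__main__":
    record = build_answer(ExecuteInput("How many rows?", llm_code="model-a"),
                          ExecuteResponse("SELECT COUNT(*) FROM t;", cot_tokens=42))
    print(record)
```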