mirror of
https://github.com/csunny/DB-GPT.git
synced 2026-01-13 19:55:44 +00:00
feat(benchmark): multi model post process
This commit is contained in:
@@ -326,10 +326,10 @@ class BenchmarkService(
|
||||
)
|
||||
|
||||
output_sets = BenchmarkDataSets[OutputType]()
|
||||
output_list = []
|
||||
output_list: List[OutputType] = []
|
||||
|
||||
written_batches = set() # 记录已写入批次
|
||||
complete_map = {} # 记录任务完成状态,使用Dict[int, OutputType]
|
||||
written_batches: set[int] = set()
|
||||
complete_map: Dict[int, OutputType] = {}
|
||||
|
||||
# 线程锁,保证线程安全
|
||||
lock = threading.Lock()
|
||||
@@ -356,7 +356,6 @@ class BenchmarkService(
|
||||
f" output={json.dumps(output.to_dict(), ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
# 线程安全地添加结果
|
||||
with lock:
|
||||
output_list.append(output)
|
||||
|
||||
@@ -490,15 +489,16 @@ class BenchmarkService(
|
||||
input_file_path: str, output_file_path: str):
|
||||
"""
|
||||
Post dispatch processing standard result compare LLM execute result
|
||||
and write compare result to file
|
||||
and write compare result to file
|
||||
"""
|
||||
self.user_input_execute_service.post_dispatch(
|
||||
i,
|
||||
config,
|
||||
input_list,
|
||||
None,
|
||||
output_list[0].benchmark_data_sets.data_list,
|
||||
input_file_path,
|
||||
output_file_path,
|
||||
)
|
||||
for j, output_result in enumerate(output_list):
|
||||
self.user_input_execute_service.post_dispatch(
|
||||
i,
|
||||
config,
|
||||
input_list,
|
||||
None,
|
||||
output_result.benchmark_data_sets.data_list,
|
||||
input_file_path,
|
||||
output_file_path,
|
||||
)
|
||||
|
||||
|
||||
@@ -61,6 +61,7 @@ class AnswerExecuteModel:
|
||||
strategyConfig: Optional[DataCompareStrategyConfig] = None
|
||||
cotTokens: Optional[Any] = None
|
||||
cost_time: Optional[int] = None
|
||||
llm_code: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d: Dict[str, Any]) -> "AnswerExecuteModel":
|
||||
@@ -83,6 +84,7 @@ class AnswerExecuteModel:
|
||||
strategyConfig=strategy_config,
|
||||
cotTokens=d.get("cotTokens"),
|
||||
cost_time=d.get("cost_time"),
|
||||
llm_code=d.get("llm_code"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
@@ -103,6 +105,7 @@ class AnswerExecuteModel:
|
||||
strategyConfig=cfg,
|
||||
cotTokens=self.cotTokens,
|
||||
cost_time=self.cost_time,
|
||||
llm_code=self.llm_code,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -268,6 +268,7 @@ class UserInputExecuteService:
|
||||
executeResult=execute_result,
|
||||
cotTokens=response.cot_tokens,
|
||||
errorMsg=error_msg,
|
||||
llm_code=input.llm_code,
|
||||
)
|
||||
|
||||
def _extract_sql_content(self, content: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user