From 92c251d297f09298ca87ac609ae03aabdc4e6179 Mon Sep 17 00:00:00 2001 From: yaoyifan-yyf Date: Mon, 29 Sep 2025 11:37:19 +0800 Subject: [PATCH] opt: async load benchmark data on init --- .../src/dbgpt_app/component_configs.py | 9 + .../dbgpt-app/src/dbgpt_app/dbgpt_server.py | 45 -- .../service/benchmark/data_compare_service.py | 67 ++- .../service/benchmark/file_parse_service.py | 104 ++-- .../evaluate/service/benchmark/models.py | 32 +- .../evaluate/service/benchmark/run_demo.py | 25 +- .../benchmark/user_input_execute_service.py | 83 ++- .../fetchdata/benchmark_data_manager.py | 515 ++++++++++++------ 8 files changed, 572 insertions(+), 308 deletions(-) diff --git a/packages/dbgpt-app/src/dbgpt_app/component_configs.py b/packages/dbgpt-app/src/dbgpt_app/component_configs.py index 623b3e996..de2c75e1b 100644 --- a/packages/dbgpt-app/src/dbgpt_app/component_configs.py +++ b/packages/dbgpt-app/src/dbgpt_app/component_configs.py @@ -59,6 +59,7 @@ def initialize_components( _initialize_code_server(system_app) # Initialize prompt templates - MUST be after serve apps registration _initialize_prompt_templates() + _initialize_benchmark_data(system_app) def _initialize_model_cache(system_app: SystemApp, web_config: ServiceWebParameters): @@ -206,3 +207,11 @@ def _initialize_prompt_templates(): logger.error(f"Failed to initialize prompt templates: {e}") # Don't raise exception to avoid breaking the application startup # The templates will be loaded lazily when needed + + +def _initialize_benchmark_data(system_app: SystemApp): + from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import ( + initialize_benchmark_data, + ) + + initialize_benchmark_data(system_app) diff --git a/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py b/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py index cc9413a33..02223bfd1 100644 --- a/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py +++ b/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py @@ -1,4 +1,3 @@ -import asyncio import logging import os import sys @@ -36,9 +35,6 @@ from dbgpt_app.base import ( from dbgpt_app.component_configs import initialize_components from dbgpt_app.config import ApplicationConfig, ServiceWebParameters, SystemParameters from dbgpt_serve.core import add_exception_handler -from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import ( - get_benchmark_manager, -) logger = logging.getLogger(__name__) ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -148,13 +144,6 @@ def initialize_app(param: ApplicationConfig, args: List[str] = None): # After init, when the database is ready system_app.after_init() - # Async fetch benchmark dataset from Falcon - loop = asyncio.get_event_loop() - if loop.is_running(): - loop.create_task(load_benchmark_data()) - else: - loop.run_until_complete(load_benchmark_data()) - binding_port = web_config.port binding_host = web_config.host if not web_config.light: @@ -330,40 +319,6 @@ def parse_args(): return parser.parse_args() -async def load_benchmark_data(): - """Load benchmark data from GitHub repository into SQLite database""" - logging.basicConfig(level=logging.INFO) - logger.info("Starting benchmark data loading process...") - - try: - manager = get_benchmark_manager(system_app) - - async with manager: - logger.info("Fetching data from GitHub repository...") - result = await manager.load_from_github( - repo_url="https://github.com/inclusionAI/Falcon", data_dir="data/source" - ) - - # Log detailed results - logger.info("\nBenchmark Data Loading 
Summary:") - logger.info(f"Total CSV files processed: {result['total_files']}") - logger.info(f"Successfully imported: {result['successful']}") - logger.info(f"Failed imports: {result['failed']}") - - if result["failed"] > 0: - logger.warning(f"Encountered {result['failed']} failures during import") - - # Verify the loaded data - table_info = await manager.get_table_info() - logger.info(f"Loaded {len(table_info)} tables into database") - - return {"import_result": result, "table_info": table_info} - - except Exception as e: - logger.error("Failed to load benchmark data", exc_info=True) - raise RuntimeError(f"Benchmark data loading failed: {str(e)}") from e - - if __name__ == "__main__": # Parse command line arguments _args = parse_args() diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data_compare_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data_compare_service.py index 24553b008..1cb6f5144 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data_compare_service.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/data_compare_service.py @@ -1,15 +1,25 @@ -from typing import Dict, List, Optional -from models import DataCompareResult, DataCompareResultEnum, DataCompareStrategyConfig, AnswerExecuteModel -from copy import deepcopy import hashlib import json -from decimal import Decimal, ROUND_HALF_UP +from copy import deepcopy +from decimal import ROUND_HALF_UP, Decimal +from typing import Dict, List, Optional + +from models import ( + AnswerExecuteModel, + DataCompareResult, + DataCompareResultEnum, + DataCompareStrategyConfig, +) + def md5_list(values: List[str]) -> str: s = ",".join([v if v is not None else "" for v in values]) return hashlib.md5(s.encode("utf-8")).hexdigest() -def accurate_decimal(table: Dict[str, List[str]], scale: int = 2) -> Dict[str, List[str]]: + +def accurate_decimal( + table: Dict[str, List[str]], scale: int = 2 +) -> Dict[str, List[str]]: out = {} for k, col in table.items(): new_col = [] @@ -20,13 +30,18 @@ def accurate_decimal(table: Dict[str, List[str]], scale: int = 2) -> Dict[str, L vs = str(v) try: d = Decimal(vs) - new_col.append(str(d.quantize(Decimal("1." + "0"*scale), rounding=ROUND_HALF_UP))) - except: + new_col.append( + str(d.quantize(Decimal("1." 
+ "0" * scale), rounding=ROUND_HALF_UP)) + ) + except Exception as e: new_col.append(vs) out[k] = new_col return out -def sort_columns_by_key(table: Dict[str, List[str]], sort_key: str) -> Dict[str, List[str]]: + +def sort_columns_by_key( + table: Dict[str, List[str]], sort_key: str +) -> Dict[str, List[str]]: if sort_key not in table: raise ValueError(f"base col not exist: {sort_key}") base = table[sort_key] @@ -41,11 +56,21 @@ def sort_columns_by_key(table: Dict[str, List[str]], sort_key: str) -> Dict[str, sorted_table[k] = [table[k][i] for i in indices] return sorted_table + class DataCompareService: - def compare(self, standard_model: AnswerExecuteModel, target_result: Optional[Dict[str, List[str]]]) -> DataCompareResult: + def compare( + self, + standard_model: AnswerExecuteModel, + target_result: Optional[Dict[str, List[str]]], + ) -> DataCompareResult: if target_result is None: return DataCompareResult.failed("targetResult is null") - cfg: DataCompareStrategyConfig = standard_model.strategyConfig or DataCompareStrategyConfig(strategy="EXACT_MATCH", order_by=True, standard_result=None) + cfg: DataCompareStrategyConfig = ( + standard_model.strategyConfig + or DataCompareStrategyConfig( + strategy="EXACT_MATCH", order_by=True, standard_result=None + ) + ) if not cfg.standard_result: return DataCompareResult.failed("leftResult is null") @@ -62,7 +87,12 @@ class DataCompareService: return res return DataCompareResult.wrong("compareResult wrong!") - def _compare_ordered(self, std: Dict[str, List[str]], cfg: DataCompareStrategyConfig, tgt: Dict[str, List[str]]) -> DataCompareResult: + def _compare_ordered( + self, + std: Dict[str, List[str]], + cfg: DataCompareStrategyConfig, + tgt: Dict[str, List[str]], + ) -> DataCompareResult: try: std_md5 = set() for col_vals in std.values(): @@ -89,7 +119,12 @@ class DataCompareService: except Exception as e: return DataCompareResult.exception(f"compareResult Exception! {e}") - def _compare_unordered(self, std: Dict[str, List[str]], cfg: DataCompareStrategyConfig, tgt: Dict[str, List[str]]) -> DataCompareResult: + def _compare_unordered( + self, + std: Dict[str, List[str]], + cfg: DataCompareStrategyConfig, + tgt: Dict[str, List[str]], + ) -> DataCompareResult: try: tgt_md5 = [] tgt_cols = [] @@ -115,7 +150,7 @@ class DataCompareService: ordered_cfg = DataCompareStrategyConfig( strategy=cfg.strategy, order_by=True, - standard_result=cfg.standard_result + standard_result=cfg.standard_result, ) res = self._compare_ordered(std_sorted, ordered_cfg, tgt_sorted) if res.compare_result == DataCompareResultEnum.RIGHT: @@ -124,7 +159,9 @@ class DataCompareService: except Exception as e: return DataCompareResult.exception(f"compareResult Exception! {e}") - def compare_json_by_config(self, standard_answer: str, answer: str, compare_config: Dict[str, str]) -> DataCompareResult: + def compare_json_by_config( + self, standard_answer: str, answer: str, compare_config: Dict[str, str] + ) -> DataCompareResult: try: if not standard_answer or not answer: return DataCompareResult.failed("standardAnswer or answer is null") @@ -141,4 +178,4 @@ class DataCompareService: return DataCompareResult.failed(f"unknown strategy {strat}") return DataCompareResult.right("json compare success") except Exception as e: - return DataCompareResult.exception(f"compareJsonByConfig Exception! {e}") \ No newline at end of file + return DataCompareResult.exception(f"compareJsonByConfig Exception! 
{e}") diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py index c0dd0ad9b..59ecc4b19 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py @@ -1,36 +1,55 @@ import json -from typing import List -from models import BaseInputModel, AnswerExecuteModel, RoundAnswerConfirmModel, DataCompareResultEnum, DataCompareStrategyConfig -import pandas as pd import os +from typing import List + +import pandas as pd +from models import ( + AnswerExecuteModel, + BaseInputModel, + DataCompareResultEnum, + DataCompareStrategyConfig, + RoundAnswerConfirmModel, +) + class FileParseService: def parse_input_sets(self, path: str) -> List[BaseInputModel]: data = [] with open(path, "r", encoding="utf-8") as f: for line in f: - if not line.strip(): continue + if not line.strip(): + continue obj = json.loads(line) - data.append(BaseInputModel( - serialNo=obj["serialNo"], - analysisModelId=obj["analysisModelId"], - question=obj["question"], - selfDefineTags=obj.get("selfDefineTags"), - prompt=obj.get("prompt"), - knowledge=obj.get("knowledge"), - )) + data.append( + BaseInputModel( + serialNo=obj["serialNo"], + analysisModelId=obj["analysisModelId"], + question=obj["question"], + selfDefineTags=obj.get("selfDefineTags"), + prompt=obj.get("prompt"), + knowledge=obj.get("knowledge"), + ) + ) return data def parse_llm_outputs(self, path: str) -> List[AnswerExecuteModel]: data = [] with open(path, "r", encoding="utf-8") as f: for line in f: - if not line.strip(): continue + if not line.strip(): + continue obj = json.loads(line) data.append(AnswerExecuteModel.from_dict(obj)) return data - def write_data_compare_result(self, path: str, round_id: int, confirm_models: List[RoundAnswerConfirmModel], is_execute: bool, llm_count: int): + def write_data_compare_result( + self, + path: str, + round_id: int, + confirm_models: List[RoundAnswerConfirmModel], + is_execute: bool, + llm_count: int, + ): if not path.endswith(".jsonl"): raise ValueError(f"output_file_path must end with .jsonl, got {path}") out_path = path.replace(".jsonl", f".round{round_id}.compare.jsonl") @@ -48,26 +67,35 @@ class FileParseService: errorMsg=cm.errorMsg, compareResult=cm.compareResult.value if cm.compareResult else None, isExecute=is_execute, - llmCount=llm_count + llmCount=llm_count, ) f.write(json.dumps(row, ensure_ascii=False) + "\n") print(f"[write_data_compare_result] compare written to: {out_path}") - def summary_and_write_multi_round_benchmark_result(self, output_path: str, round_id: int) -> str: + def summary_and_write_multi_round_benchmark_result( + self, output_path: str, round_id: int + ) -> str: if not output_path.endswith(".jsonl"): - raise ValueError(f"output_file_path must end with .jsonl, got {output_path}") + raise ValueError( + f"output_file_path must end with .jsonl, got {output_path}" + ) compare_path = output_path.replace(".jsonl", f".round{round_id}.compare.jsonl") right, wrong, failed, exception = 0, 0, 0, 0 if os.path.exists(compare_path): with open(compare_path, "r", encoding="utf-8") as f: for line in f: - if not line.strip(): continue + if not line.strip(): + continue obj = json.loads(line) cr = obj.get("compareResult") - if cr == DataCompareResultEnum.RIGHT.value: right += 1 - elif cr == DataCompareResultEnum.WRONG.value: wrong += 1 - elif cr == 
DataCompareResultEnum.FAILED.value: failed += 1 - elif cr == DataCompareResultEnum.EXCEPTION.value: exception += 1 + if cr == DataCompareResultEnum.RIGHT.value: + right += 1 + elif cr == DataCompareResultEnum.WRONG.value: + wrong += 1 + elif cr == DataCompareResultEnum.FAILED.value: + failed += 1 + elif cr == DataCompareResultEnum.EXCEPTION.value: + exception += 1 else: print(f"[summary] compare file not found: {compare_path}") summary_path = output_path.replace(".jsonl", f".round{round_id}.summary.json") @@ -77,7 +105,9 @@ class FileParseService: print(f"[summary] summary written to: {summary_path} -> {result}") return json.dumps(result, ensure_ascii=False) - def parse_standard_benchmark_sets(self, standard_excel_path: str) -> List[AnswerExecuteModel]: + def parse_standard_benchmark_sets( + self, standard_excel_path: str + ) -> List[AnswerExecuteModel]: df = pd.read_excel(standard_excel_path, sheet_name="Sheet1") outputs: List[AnswerExecuteModel] = [] for _, row in df.iterrows(): @@ -87,7 +117,9 @@ class FileParseService: continue question = row.get("用户问题") analysis_model_id = row.get("数据集ID") - llm_output = None if pd.isna(row.get("标准答案SQL")) else str(row.get("标准答案SQL")) + llm_output = ( + None if pd.isna(row.get("标准答案SQL")) else str(row.get("标准答案SQL")) + ) order_by = True if not pd.isna(row.get("是否排序")): try: @@ -105,14 +137,18 @@ class FileParseService: strategy_config = DataCompareStrategyConfig( strategy="CONTAIN_MATCH", order_by=order_by, - standard_result=[std_result] if std_result is not None else None # 使用 list + standard_result=[std_result] + if std_result is not None + else None, # 使用 list ) - outputs.append(AnswerExecuteModel( - serialNo=serial_no, - analysisModelId=analysis_model_id, - question=question, - llmOutput=llm_output, - executeResult=std_result, - strategyConfig=strategy_config - )) - return outputs \ No newline at end of file + outputs.append( + AnswerExecuteModel( + serialNo=serial_no, + analysisModelId=analysis_model_id, + question=question, + llmOutput=llm_output, + executeResult=std_result, + strategyConfig=strategy_config, + ) + ) + return outputs diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py index b606ca5d5..3809ea179 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py @@ -1,37 +1,48 @@ -# app/services/models.py from dataclasses import dataclass from enum import Enum from typing import Any, Dict, List, Optional + class BenchmarkModeTypeEnum(str, Enum): BUILD = "BUILD" EXECUTE = "EXECUTE" + @dataclass class DataCompareStrategyConfig: strategy: str # "EXACT_MATCH" | "CONTAIN_MATCH" order_by: bool = True standard_result: Optional[List[Dict[str, List[str]]]] = None # 改为 list[dict] + class DataCompareResultEnum(str, Enum): RIGHT = "RIGHT" WRONG = "WRONG" FAILED = "FAILED" EXCEPTION = "EXCEPTION" + @dataclass class DataCompareResult: compare_result: DataCompareResultEnum msg: str = "" @staticmethod - def right(msg=""): return DataCompareResult(DataCompareResultEnum.RIGHT, msg) + def right(msg=""): + return DataCompareResult(DataCompareResultEnum.RIGHT, msg) + @staticmethod - def wrong(msg=""): return DataCompareResult(DataCompareResultEnum.WRONG, msg) + def wrong(msg=""): + return DataCompareResult(DataCompareResultEnum.WRONG, msg) + @staticmethod - def failed(msg=""): return DataCompareResult(DataCompareResultEnum.FAILED, msg) + def 
failed(msg=""): + return DataCompareResult(DataCompareResultEnum.FAILED, msg) + @staticmethod - def exception(msg=""): return DataCompareResult(DataCompareResultEnum.EXCEPTION, msg) + def exception(msg=""): + return DataCompareResult(DataCompareResultEnum.EXCEPTION, msg) + @dataclass class BaseInputModel: @@ -42,6 +53,7 @@ class BaseInputModel: prompt: Optional[str] = None knowledge: Optional[str] = None + @dataclass class AnswerExecuteModel: serialNo: int @@ -62,7 +74,7 @@ class AnswerExecuteModel: strategy_config = DataCompareStrategyConfig( strategy=cfg.get("strategy"), order_by=cfg.get("order_by", True), - standard_result=std_list if isinstance(std_list, list) else None + standard_result=std_list if isinstance(std_list, list) else None, ) return AnswerExecuteModel( serialNo=d["serialNo"], @@ -81,7 +93,7 @@ class AnswerExecuteModel: cfg = dict( strategy=self.strategyConfig.strategy, order_by=self.strategyConfig.order_by, - standard_result=self.strategyConfig.standard_result + standard_result=self.strategyConfig.standard_result, ) return dict( serialNo=self.serialNo, @@ -91,9 +103,10 @@ class AnswerExecuteModel: executeResult=self.executeResult, errorMsg=self.errorMsg, strategyConfig=cfg, - cotTokens=self.cotTokens + cotTokens=self.cotTokens, ) + @dataclass class RoundAnswerConfirmModel: serialNo: int @@ -108,9 +121,10 @@ class RoundAnswerConfirmModel: errorMsg: Optional[str] = None compareResult: Optional[DataCompareResultEnum] = None + @dataclass class BenchmarkExecuteConfig: benchmarkModeType: BenchmarkModeTypeEnum compareResultEnable: bool standardFilePath: Optional[str] = None - compareConfig: Optional[Dict[str, str]] = None \ No newline at end of file + compareConfig: Optional[Dict[str, str]] = None diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/run_demo.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/run_demo.py index f598a69b2..855e8688d 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/run_demo.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/run_demo.py @@ -1,7 +1,8 @@ -from file_parse_service import FileParseService from data_compare_service import DataCompareService -from user_input_execute_service import UserInputExecuteService +from file_parse_service import FileParseService from models import BenchmarkExecuteConfig, BenchmarkModeTypeEnum +from user_input_execute_service import UserInputExecuteService + def run_build_mode(): fps = FileParseService() @@ -16,7 +17,7 @@ def run_build_mode(): benchmarkModeType=BenchmarkModeTypeEnum.BUILD, compareResultEnable=True, standardFilePath=None, - compareConfig={"check":"FULL_TEXT"} + compareConfig={"check": "FULL_TEXT"}, ) svc.post_dispatch( @@ -26,12 +27,15 @@ def run_build_mode(): left_outputs=left, right_outputs=right, input_file_path="data/input_round1.jsonl", - output_file_path="data/output_round1_modelB.jsonl" + output_file_path="data/output_round1_modelB.jsonl", ) - fps.summary_and_write_multi_round_benchmark_result("data/output_round1_modelB.jsonl", 1) + fps.summary_and_write_multi_round_benchmark_result( + "data/output_round1_modelB.jsonl", 1 + ) print("BUILD compare path:", "data/output_round1_modelB.round1.compare.jsonl") + def run_execute_mode(): fps = FileParseService() dcs = DataCompareService() @@ -44,7 +48,7 @@ def run_execute_mode(): benchmarkModeType=BenchmarkModeTypeEnum.EXECUTE, compareResultEnable=True, standardFilePath="data/standard_answers.xlsx", - compareConfig=None + compareConfig=None, ) svc.post_dispatch( @@ 
-54,12 +58,15 @@ def run_execute_mode(): left_outputs=[], right_outputs=right, input_file_path="data/input_round1.jsonl", - output_file_path="data/output_execute_model.jsonl" + output_file_path="data/output_execute_model.jsonl", ) - fps.summary_and_write_multi_round_benchmark_result("data/output_execute_model.jsonl", 1) + fps.summary_and_write_multi_round_benchmark_result( + "data/output_execute_model.jsonl", 1 + ) print("EXECUTE compare path:", "data/output_execute_model.round1.compare.jsonl") + if __name__ == "__main__": run_build_mode() - run_execute_mode() \ No newline at end of file + run_execute_mode() diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py index 1bba0034d..68fefb5f9 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py @@ -1,14 +1,23 @@ # app/services/user_input_execute_service.py from typing import List -from models import ( - BaseInputModel, AnswerExecuteModel, RoundAnswerConfirmModel, - BenchmarkExecuteConfig, BenchmarkModeTypeEnum, DataCompareResultEnum, DataCompareStrategyConfig -) -from file_parse_service import FileParseService + from data_compare_service import DataCompareService +from file_parse_service import FileParseService +from models import ( + AnswerExecuteModel, + BaseInputModel, + BenchmarkExecuteConfig, + BenchmarkModeTypeEnum, + DataCompareResultEnum, + DataCompareStrategyConfig, + RoundAnswerConfirmModel, +) + class UserInputExecuteService: - def __init__(self, file_service: FileParseService, compare_service: DataCompareService): + def __init__( + self, file_service: FileParseService, compare_service: DataCompareService + ): self.file_service = file_service self.compare_service = compare_service @@ -20,16 +29,38 @@ class UserInputExecuteService: left_outputs: List[AnswerExecuteModel], right_outputs: List[AnswerExecuteModel], input_file_path: str, - output_file_path: str + output_file_path: str, ): try: - if config.benchmarkModeType == BenchmarkModeTypeEnum.BUILD and config.compareResultEnable: + if ( + config.benchmarkModeType == BenchmarkModeTypeEnum.BUILD + and config.compareResultEnable + ): if left_outputs and right_outputs: - self._execute_llm_compare_result(output_file_path, round_id, inputs, left_outputs, right_outputs, config) - elif config.benchmarkModeType == BenchmarkModeTypeEnum.EXECUTE and config.compareResultEnable: + self._execute_llm_compare_result( + output_file_path, + round_id, + inputs, + left_outputs, + right_outputs, + config, + ) + elif ( + config.benchmarkModeType == BenchmarkModeTypeEnum.EXECUTE + and config.compareResultEnable + ): if config.standardFilePath and right_outputs: - standard_sets = self.file_service.parse_standard_benchmark_sets(config.standardFilePath) - self._execute_llm_compare_result(output_file_path, 1, inputs, standard_sets, right_outputs, config) + standard_sets = self.file_service.parse_standard_benchmark_sets( + config.standardFilePath + ) + self._execute_llm_compare_result( + output_file_path, + 1, + inputs, + standard_sets, + right_outputs, + config, + ) except Exception as e: print(f"[post_dispatch] compare error: {e}") @@ -40,7 +71,7 @@ class UserInputExecuteService: inputs: List[BaseInputModel], left_answers: List[AnswerExecuteModel], right_answers: List[AnswerExecuteModel], - config: 
BenchmarkExecuteConfig + config: BenchmarkExecuteConfig, ): left_map = {a.serialNo: a for a in left_answers} right_map = {a.serialNo: a for a in right_answers} @@ -66,13 +97,17 @@ class UserInputExecuteService: strategy_cfg = DataCompareStrategyConfig( strategy="EXACT_MATCH", order_by=True, - standard_result=standard_result_list if standard_result_list else None + standard_result=standard_result_list + if standard_result_list + else None, ) if right is not None: if config.compareConfig and isinstance(config.compareConfig, dict): res = self.compare_service.compare_json_by_config( - left.llmOutput if left else "", right.llmOutput or "", config.compareConfig + left.llmOutput if left else "", + right.llmOutput or "", + config.compareConfig, ) compare_result = res.compare_result else: @@ -80,14 +115,16 @@ class UserInputExecuteService: compare_result = DataCompareResultEnum.FAILED else: res = self.compare_service.compare( - left if left else AnswerExecuteModel( + left + if left + else AnswerExecuteModel( serialNo=inp.serialNo, analysisModelId=inp.analysisModelId, question=inp.question, llmOutput=None, - executeResult=None + executeResult=None, ), - right.executeResult + right.executeResult, ) compare_result = res.compare_result confirm = RoundAnswerConfirmModel( @@ -101,8 +138,14 @@ class UserInputExecuteService: llmOutput=right.llmOutput if right else None, executeResult=right.executeResult if right else None, errorMsg=right.errorMsg if right else None, - compareResult=compare_result + compareResult=compare_result, ) confirm_list.append(confirm) - self.file_service.write_data_compare_result(location, round_id, confirm_list, config.benchmarkModeType == BenchmarkModeTypeEnum.EXECUTE, 2) \ No newline at end of file + self.file_service.write_data_compare_result( + location, + round_id, + confirm_list, + config.benchmarkModeType == BenchmarkModeTypeEnum.EXECUTE, + 2, + ) diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py index a7c6bd867..141fbe3bc 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py @@ -5,17 +5,18 @@ import json import logging import os import shutil -import sqlite3 import tempfile import time import zipfile from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import aiohttp +from sqlalchemy import text from dbgpt._private.pydantic import BaseModel, ConfigDict from dbgpt.component import BaseComponent, ComponentType, SystemApp +from dbgpt_ext.datasource.rdbms.conn_sqlite import SQLiteConnector logger = logging.getLogger(__name__) @@ -26,9 +27,11 @@ class BenchmarkDataConfig(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) cache_dir: str = "cache" - db_path: str = "pilot/benchmark_meta_data/benchmark_data.db" + db_path: str = "pilot/benchmark_meta_data/ant_icube_dev.db" table_mapping_file: str = "pilot/benchmark_meta_data/table_mapping.json" cache_expiry_days: int = 1 + repo_url: str = "https://github.com/inclusionAI/Falcon" + data_dir: str = "data/source" class BenchmarkDataManager(BaseComponent): @@ -41,56 +44,174 @@ class BenchmarkDataManager(BaseComponent): ): super().__init__(system_app) self._config = config or BenchmarkDataConfig() - self._http_session = None - self._db_conn = None + self._http_session: 
Optional[aiohttp.ClientSession] = None + self._connector: Optional[SQLiteConnector] = None self._table_mappings = self._load_mappings() self._lock = asyncio.Lock() - self.temp_dir = None + self.temp_dir: Optional[str] = None # Ensure directories exist os.makedirs(self._config.cache_dir, exist_ok=True) + db_dir = os.path.dirname(self._config.db_path) + if db_dir: + os.makedirs(db_dir, exist_ok=True) + self._startup_loaded: bool = False def init_app(self, system_app: SystemApp): """Initialize the AgentManager.""" self.system_app = system_app + async def async_after_start(self): + logger.info("BenchmarkDataManager: async_after_start.") + + try: + if not self._config.repo_url: + logger.info("BenchmarkDataManager: repo_url not set, skip auto load.") + return + + if self._startup_loaded: + logger.info("BenchmarkDataManager: already loaded on startup, skip.") + return + + logger.info( + f"BenchmarkDataManager: auto loading repo {self._config.repo_url} " + f"dir={self._config.data_dir}" + ) + await get_benchmark_manager(self.system_app).load_from_github( + repo_url=self._config.repo_url, data_dir=self._config.data_dir + ) + self._startup_loaded = True + logger.info("BenchmarkDataManager: auto load finished.") + except Exception as e: + logger.error(f"BenchmarkDataManager: auto load failed: {e}") + + async def async_before_stop(self): + try: + logger.info("BenchmarkDataManager: closing resources before stop...") + await self.close() + logger.info("BenchmarkDataManager: close done.") + except Exception as e: + logger.warning(f"BenchmarkDataManager: close failed: {e}") + async def __aenter__(self): self._http_session = aiohttp.ClientSession() + await self.init_connector() return self async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() + async def init_connector(self): + """Initialize SQLiteConnector""" + async with self._lock: + if not self._connector: + self._connector = SQLiteConnector.from_file_path(self._config.db_path) + + async def close_connector(self): + """Close SQLiteConnector""" + async with self._lock: + if self._connector: + try: + self._connector.close() + except Exception as e: + logger.warning(f"Close connector failed: {e}") + self._connector = None + async def close(self): """Clean up resources""" if self._http_session: await self._http_session.close() self._http_session = None - if self._db_conn: - self._db_conn.close() - self._db_conn = None + await self.close_connector() self._cleanup_temp_dir() - async def get_connection(self) -> sqlite3.Connection: - """Get database connection (thread-safe)""" - async with self._lock: - if not self._db_conn: - self._db_conn = sqlite3.connect(self._config.db_path) - return self._db_conn + async def _run_in_thread(self, func, *args, **kwargs): + """Run blocking function in thread to avoid blocking event loop""" + return await asyncio.to_thread(func, *args, **kwargs) + + def _sanitize_column_name(self, name: str) -> str: + if name is None: + return "" + name = str(name).strip().strip('"').strip("'") + invalid_chars = [ + "-", + " ", + ".", + ",", + ";", + ":", + "!", + "?", + "'", + '"', + "(", + ")", + "[", + "]", + "{", + "}", + "\t", + "\r", + "\n", + "\x00", + ] + for ch in invalid_chars: + name = name.replace(ch, "_") + while "__" in name: + name = name.replace("__", "_") + if name and not (name[0].isalpha() or name[0] == "_"): + name = "_" + name + return name.lower() + + def _sanitize_and_dedup_headers(self, headers: List[str]) -> List[str]: + sanitized: List[str] = [] + used: set = set() + for idx, h in 
enumerate(headers): + name = self._sanitize_column_name(h) + if not name: + name = f"col_{idx}" + base = name + k = 2 + while name in used or not name: + name = f"{base}_{k}" + k += 1 + used.add(name) + sanitized.append(name) + return sanitized + + # ========================================================== + + # 通用查询(阻塞实现,在线程池中调用) + def _query_blocking(self, sql: str, params: Optional[Dict[str, Any]] = None): + assert self._connector is not None, "Connector not initialized" + with self._connector.session_scope() as session: + cursor = session.execute(text(sql), params or {}) + rows = cursor.fetchall() + # SQLAlchemy 2.0: cursor.keys() 提供列名 + cols = list(cursor.keys()) + return cols, rows + + # 通用写入(阻塞实现,在线程池中调用) + def _execute_blocking(self, sql: str, params: Optional[Dict[str, Any]] = None): + assert self._connector is not None, "Connector not initialized" + with self._connector.session_scope() as session: + result = session.execute(text(sql), params or {}) + session.commit() + return result.rowcount async def query(self, query: str, params: tuple = ()) -> List[Dict]: """Execute query and return results as dict list""" - conn = await self.get_connection() - cursor = conn.cursor() - cursor.execute(query, params) - columns = [col[0] for col in cursor.description] - return [dict(zip(columns, row)) for row in cursor.fetchall()] + await self.init_connector() + cols, rows = await self._run_in_thread(self._query_blocking, query, params) + return [dict(zip(cols, row)) for row in rows] async def load_from_github( self, repo_url: str, data_dir: str = "data/source" ) -> Dict: """Main method to load data from GitHub repository""" try: + await self.init_connector() + # 1. Download or use cached repository repo_dir = await self._download_repo_contents(repo_url) @@ -106,31 +227,38 @@ class BenchmarkDataManager(BaseComponent): except Exception as e: logger.error(f"Import failed: {str(e)}") - raise + raise RuntimeError(f"Benchmark data loading failed: {e}") from e finally: self._cleanup_temp_dir() async def get_table_info(self) -> Dict: """Get metadata about all tables""" - conn = await self.get_connection() - cursor = conn.cursor() + await self.init_connector() + assert self._connector is not None - cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") - tables = cursor.fetchall() + def _work(): + with self._connector.session_scope() as session: + tables = session.execute( + text("SELECT name FROM sqlite_master WHERE type='table'") + ) + tables = [row[0] for row in tables.fetchall()] + result: Dict[str, Any] = {} + for table_name in tables: + row_count = session.execute( + text(f'SELECT COUNT(*) FROM "{table_name}"') + ).fetchone()[0] + columns = session.execute( + text(f'PRAGMA table_info("{table_name}")') + ).fetchall() + result[table_name] = { + "row_count": row_count, + "columns": [ + {"name": col[1], "type": col[2]} for col in columns + ], + } + return result - result = {} - for table in tables: - table_name = table[0] - cursor.execute(f"SELECT COUNT(*) FROM {table_name}") - row_count = cursor.fetchone()[0] - cursor.execute(f"PRAGMA table_info({table_name})") - columns = cursor.fetchall() - - result[table_name] = { - "row_count": row_count, - "columns": [{"name": col[1], "type": col[2]} for col in columns], - } - return result + return await self._run_in_thread(_work) def clear_cache(self): """Clear cached repository files""" @@ -214,6 +342,8 @@ class BenchmarkDataManager(BaseComponent): logger.info(f"Downloading from GitHub repo: {zip_url}") try: + if self._http_session is None: + 
self._http_session = aiohttp.ClientSession() async with self._http_session.get(zip_url) as response: response.raise_for_status() zip_path = os.path.join(self.temp_dir, "repo.zip") @@ -233,7 +363,7 @@ class BenchmarkDataManager(BaseComponent): except Exception as e: self._cleanup_temp_dir() - raise RuntimeError(f"Failed to download repository: {str(e)}") + raise RuntimeError(f"Failed to download repository: {str(e)}") from e def _get_cache_path(self, repo_url: str) -> str: """Get path to cached zip file""" @@ -288,8 +418,8 @@ class BenchmarkDataManager(BaseComponent): async def _import_to_database(self, csv_files: List[Dict]) -> Dict: """Import CSV data to SQLite""" - conn = await self.get_connection() - cursor = conn.cursor() + await self.init_connector() + assert self._connector is not None results = { "total_files": len(csv_files), "successful": 0, @@ -297,13 +427,13 @@ class BenchmarkDataManager(BaseComponent): "tables_created": [], } - for file_info in csv_files: + def _process_one_file(file_info: Dict) -> Tuple[bool, Optional[str]]: + table_name = "" try: path_parts = [p for p in file_info["rel_path"].split(os.sep) if p] table_name = "_".join(path_parts + [Path(file_info["file_name"]).stem]) table_name = self._sanitize_table_name(table_name) - # Try multiple encodings encodings = ["utf-8-sig", "utf-8", "latin-1", "iso-8859-1", "cp1252"] for encoding in encodings: @@ -311,187 +441,220 @@ class BenchmarkDataManager(BaseComponent): with open(file_info["full_path"], "r", encoding=encoding) as f: content = f.read() - # Handle empty files - if not content.strip(): - raise ValueError("File is empty") + if not content.strip(): + raise ValueError("File is empty") - # Replace problematic line breaks if needed - content = content.replace("\r\n", "\n").replace("\r", "\n") + content = content.replace("\r\n", "\n").replace("\r", "\n") + lines = [line for line in content.split("\n") if line.strip()] + if not lines: + raise ValueError("No data after normalization") - # Split into lines - lines = [ - line for line in content.split("\n") if line.strip() - ] + header_line = lines[0] + data_line = lines[1] if len(lines) > 1 else "" + + try: + sample_for_sniff = "\n".join(lines[:10]) + sniffer = csv.Sniffer() + try: + dialect = sniffer.sniff(sample_for_sniff) + except Exception: + + class _DefaultDialect(csv.Dialect): + delimiter = "," + quotechar = '"' + doublequote = True + skipinitialspace = False + lineterminator = "\n" + quoting = csv.QUOTE_MINIMAL + + dialect = _DefaultDialect() try: - header_line = lines[0] - data_line = lines[1] if len(lines) > 1 else "" + has_header = sniffer.has_header("\n".join(lines[:50])) + except Exception: + has_header = True - # Detect delimiter (comma, semicolon, tab) - sniffer = csv.Sniffer() - dialect = sniffer.sniff(header_line) - has_header = sniffer.has_header(content[:1024]) + header_row = ( + list(csv.reader([header_line], dialect))[0] + if header_line + else [] + ) + first_data_row = ( + list(csv.reader([data_line], dialect))[0] + if data_line + else [] + ) - if has_header: - headers = list(csv.reader([header_line], dialect))[ - 0 - ] - first_data_row = ( - list(csv.reader([data_line], dialect))[0] - if data_line - else [] - ) - else: - headers = list(csv.reader([header_line], dialect))[ - 0 - ] - first_data_row = headers # first line is data - headers = [f"col_{i}" for i in range(len(headers))] + if not has_header: + num_cols_guess = len(header_row) + headers = [f"col_{i}" for i in range(num_cols_guess)] + first_data_row = header_row + else: + headers = 
header_row - # Determine actual number of columns from data - actual_columns = ( - len(first_data_row) - if first_data_row - else len(headers) + num_cols = ( + len(first_data_row) if first_data_row else len(headers) + ) + + # no header + if not headers or all( + (not str(h).strip()) for h in headers + ): + headers = [f"col_{i}" for i in range(num_cols or 1)] + + headers = self._sanitize_and_dedup_headers(headers) + + if num_cols <= 0: + num_cols = len(headers) + headers = headers[:num_cols] + if not headers or any( + h is None or h == "" for h in headers + ): + raise csv.Error("Invalid headers after sanitization") + + create_sql = f''' + CREATE TABLE IF NOT EXISTS "{table_name}" ( + {", ".join([f'"{h}" TEXT' for h in headers])} ) + ''' + insert_sql = f''' + INSERT INTO "{table_name}" ({ + ", ".join([f'"{h}"' for h in headers]) + }) + VALUES ({ + ", ".join([":" + f"p{i}" for i in range(len(headers))]) + }) + ''' - # Create table with correct number of columns - create_sql = f""" - CREATE TABLE IF NOT EXISTS {table_name} ({ - ", ".join( - [ - f'"{h}" TEXT' - for h in headers[:actual_columns] - ] - ) - }) - """ - cursor.execute(create_sql) + with self._connector.session_scope() as session: + logger.debug( + f"Table: {table_name}, headers(final): {headers}" + ) + session.execute(text(create_sql)) - # Prepare insert statement - insert_sql = f""" - INSERT INTO {table_name} VALUES ({ - ", ".join(["?"] * actual_columns) - }) - """ - - # Process data - batch = [] reader = csv.reader(lines, dialect) if has_header: - next(reader) # skip header + next(reader, None) + batch_params: List[Dict[str, Any]] = [] for row in reader: - if not row: # skip empty rows + if not row: continue - - # Ensure row has correct number of columns - if len(row) != actual_columns: - if len(row) < actual_columns: - row += [None] * (actual_columns - len(row)) + if len(row) != len(headers): + if len(row) < len(headers): + row += [None] * (len(headers) - len(row)) else: - row = row[:actual_columns] + row = row[: len(headers)] + params = { + f"p{i}": (row[i] if i < len(row) else None) + for i in range(len(headers)) + } + batch_params.append(params) + if len(batch_params) >= 1000: + session.execute(text(insert_sql), batch_params) + batch_params = [] + if batch_params: + session.execute(text(insert_sql), batch_params) + session.commit() - batch.append(row) - if len(batch) >= 1000: - cursor.executemany(insert_sql, batch) - batch = [] + return True, table_name - if batch: - cursor.executemany(insert_sql, batch) - - results["successful"] += 1 - results["tables_created"].append(table_name) - break - - except csv.Error as e: - # Fallback for malformed CSV files - self._import_with_simple_split( - cursor, table_name, content, results, file_info - ) - break + except csv.Error: + self._import_with_simple_split_blocking(table_name, content) + return True, table_name except UnicodeDecodeError: continue except Exception as e: logger.warning(f"Error with encoding {encoding}: {str(e)}") continue - else: - # All encodings failed - try binary mode as last resort - try: - with open(file_info["full_path"], "rb") as f: - content = f.read().decode("ascii", errors="ignore") - if content.strip(): - self._import_with_simple_split( - cursor, table_name, content, results, file_info - ) - else: - raise ValueError("File is empty or unreadable") - except Exception as e: - results["failed"] += 1 - logger.error( - f"Failed to process {file_info['file_name']}: {str(e)}" - ) + + try: + with open(file_info["full_path"], "rb") as f: + content = 
f.read().decode("ascii", errors="ignore") + if content.strip(): + self._import_with_simple_split_blocking(table_name, content) + return True, table_name + else: + raise ValueError("File is empty or unreadable") + except Exception as e: + return ( + False, + f"Failed to process {file_info['file_name']}: {str(e)}", + ) except Exception as e: - results["failed"] += 1 - logger.error(f"Failed to process {file_info['full_path']}: {str(e)}") + return ( + False, + f"Failed to process {file_info.get('full_path', '')}: {str(e)}", + ) + + for file_info in csv_files: + ok, info = await self._run_in_thread(_process_one_file, file_info) + if ok: + results["successful"] += 1 + if info: + results["tables_created"].append(info) + else: + results["failed"] += 1 + logger.error(info) - self._db_conn.commit() return results - def _import_with_simple_split( - self, cursor, table_name, content, results, file_info - ): - """Fallback method for malformed CSV files""" - # Normalize line endings + def _import_with_simple_split_blocking(self, table_name: str, content: str): + """Fallback method for malformed CSV files (blocking, 使用 SQLAlchemy 执行)""" + assert self._connector is not None content = content.replace("\r\n", "\n").replace("\r", "\n") lines = [line for line in content.split("\n") if line.strip()] - if not lines: raise ValueError("No data found after cleaning") - # Determine delimiter first_line = lines[0] delimiter = "," if "," in first_line else "\t" if "\t" in first_line else ";" - # Process header - headers = first_line.split(delimiter) + raw_headers = first_line.split(delimiter) + headers = self._sanitize_and_dedup_headers(raw_headers) actual_columns = len(headers) - # Create table create_sql = f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - {", ".join([f"col_{i} TEXT" for i in range(actual_columns)])} - ) + CREATE TABLE IF NOT EXISTS "{table_name}" ( + {", ".join([f'"{h}" TEXT' for h in headers])} + ) """ - cursor.execute(create_sql) - # Prepare insert insert_sql = f""" - INSERT INTO {table_name} VALUES ({", ".join(["?"] * actual_columns)}) + INSERT INTO "{table_name}" ({", ".join([f'"{h}"' for h in headers])}) + VALUES ({", ".join([":" + f"p{i}" for i in range(actual_columns)])}) """ - # Process data - batch = [] - for line in lines[1:]: # skip header - row = line.split(delimiter) - if len(row) != actual_columns: - if len(row) < actual_columns: - row += [None] * (actual_columns - len(row)) - else: - row = row[:actual_columns] - batch.append(row) + with self._connector.session_scope() as session: + session.execute(text(create_sql)) + batch: List[Dict[str, Any]] = [] + for line in lines[1:]: + row = line.split(delimiter) + if len(row) != actual_columns: + if len(row) < actual_columns: + row += [None] * (actual_columns - len(row)) + else: + row = row[:actual_columns] + params = {f"p{i}": row[i] for i in range(actual_columns)} + batch.append(params) + if len(batch) >= 1000: + session.execute(text(insert_sql), batch) + batch = [] + if batch: + session.execute(text(insert_sql), batch) + session.commit() - if len(batch) >= 1000: - cursor.executemany(insert_sql, batch) - batch = [] + async def get_table_info_simple(self) -> List[str]: + """Return simplified table info: table(column1,column2,...)""" + await self.init_connector() + assert self._connector is not None - if batch: - cursor.executemany(insert_sql, batch) + def _work(): + return list(self._connector.table_simple_info()) - results["successful"] += 1 - results["tables_created"].append(table_name) + return await self._run_in_thread(_work) def 
_cleanup_temp_dir(self): """Clean up temporary directory"""
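
Note (not part of the patch): a minimal sketch of how the refactored, component-based flow introduced above might be exercised. It assumes the initialize_benchmark_data() registration added in component_configs.py and the get_benchmark_manager() accessor referenced in benchmark_data_manager.py; the table name in the ad-hoc query is a hypothetical placeholder.

# Illustrative only -- a rough sketch under the assumptions stated above,
# not the patch's own startup path (which runs via async_after_start()).
from dbgpt.component import SystemApp
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
    get_benchmark_manager,
    initialize_benchmark_data,
)


async def demo(system_app: SystemApp):
    # Registration normally happens inside initialize_components();
    # it is shown explicitly here for an ad-hoc run.
    initialize_benchmark_data(system_app)
    manager = get_benchmark_manager(system_app)

    # On a normal server start, async_after_start() triggers this load
    # automatically; calling it directly is only useful for manual runs.
    result = await manager.load_from_github(
        repo_url="https://github.com/inclusionAI/Falcon", data_dir="data/source"
    )
    print(f"imported={result['successful']} failed={result['failed']}")

    # Inspect what landed in the SQLite database.
    tables = await manager.get_table_info()
    for name, meta in tables.items():
        print(name, meta["row_count"])

    # Ad-hoc query through the SQLAlchemy-backed connector.
    # "some_benchmark_table" is a hypothetical table name.
    rows = await manager.query('SELECT * FROM "some_benchmark_table" LIMIT 5')
    print(rows)

    await manager.close()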