diff --git a/packages/dbgpt-core/src/dbgpt/util/benchmarks/StorageUtil.py b/packages/dbgpt-core/src/dbgpt/util/benchmarks/StorageUtil.py index 4776fd002..8968789c8 100644 --- a/packages/dbgpt-core/src/dbgpt/util/benchmarks/StorageUtil.py +++ b/packages/dbgpt-core/src/dbgpt/util/benchmarks/StorageUtil.py @@ -10,6 +10,8 @@ class StorageUtil: YUQUE_URL_PREFIX = "https://yuque.com" + GITHUB_FALCON_PREFIX = "https://github.com/eosphoros-ai/Falcon" + @staticmethod def get_file_parse_type(file_path: Optional[str]) -> FileParseTypeEnum: """Get file parsing type based on file path @@ -28,5 +30,7 @@ class StorageUtil: if file_path.strip().startswith(StorageUtil.YUQUE_URL_PREFIX): return FileParseTypeEnum.YU_QUE + if file_path.strip().startswith(StorageUtil.GITHUB_FALCON_PREFIX): + return FileParseTypeEnum.GITHUB return FileParseTypeEnum.EXCEL diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/__init__.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/__init__.py index 185e84d9b..800c5afb2 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/__init__.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/__init__.py @@ -1,11 +1,5 @@ from .benchmark_service import BenchmarkService -from .data_compare_service import DataCompareService -from .file_parse_service import FileParseService -from .user_input_execute_service import UserInputExecuteService __all__ = [ "BenchmarkService", - "FileParseService", - "UserInputExecuteService", - "DataCompareService", ] diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/benchmark_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/benchmark_service.py index 410b351bb..6ca77169e 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/benchmark_service.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/benchmark_service.py @@ -18,6 +18,7 @@ from dbgpt.model import DefaultLLMClient from dbgpt.model.cluster import WorkerManagerFactory from dbgpt.storage.metadata import BaseDao from dbgpt.util import PaginationResult, get_or_create_event_loop +from dbgpt.util.benchmarks import StorageUtil from dbgpt_serve.evaluate.service.benchmark.task.benchmark_agent_task import ( BenchmarkAgentTask, ) @@ -40,7 +41,6 @@ from ...config import ServeConfig from ...models.models import ServeDao, ServeEntity from ..fetchdata.benchmark_data_manager import get_benchmark_manager from .data_compare_service import DataCompareService -from .ext.excel_file_parse import ExcelFileParseService from .models import ( BaseInputModel, BenchmarkDataSets, @@ -49,7 +49,6 @@ from .models import ( BenchmarkModeTypeEnum, BenchmarkTaskResult, ContentTypeEnum, - FileParseTypeEnum, InputType, OutputType, ) @@ -61,10 +60,7 @@ executor = ThreadPoolExecutor(max_workers=5) BENCHMARK_SERVICE_COMPONENT_NAME = "dbgpt_serve_evaluate_benchmark_service" -STANDARD_BENCHMARK_FILE_PATH = os.path.join( - BENCHMARK_DATA_ROOT_PATH, - "2025_07_27_public_500_standard_benchmark_question_list.xlsx", -) +STANDARD_BENCHMARK_FILE_PATH = "https://github.com/eosphoros-ai/Falcon" BENCHMARK_OUTPUT_RESULT_PATH = os.path.join(BENCHMARK_DATA_ROOT_PATH, "result") @@ -94,11 +90,10 @@ class BenchmarkService( super().__init__(system_app) self.rag_service = get_rag_service(system_app) self.prompt_service = get_prompt_service(system_app) - self._file_parse_type = FileParseTypeEnum.EXCEL + self._file_parse_type = StorageUtil.get_file_parse_type(STANDARD_BENCHMARK_FILE_PATH) 
- fps = ExcelFileParseService() dcs = DataCompareService() - self.user_input_execute_service = UserInputExecuteService(fps, dcs) + self.user_input_execute_service = UserInputExecuteService(dcs, self._file_parse_type) self.trigger_executor = ThreadPoolExecutor( max_workers=5, thread_name_prefix="benchmark-fileWrite" @@ -289,7 +284,7 @@ class BenchmarkService( await manager.load_data() logger.info( f"Benchmark dataset loaded from {manager._config.repo_url} " - f"dir={manager._config.data_dir}" + f"dir={manager._config.data_dirs}" ) except Exception as e: logger.error( diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/ext/__init__.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/ext/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/ext/excel_file_parse.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/ext/excel_file_parse.py new file mode 100644 index 000000000..618706756 --- /dev/null +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/ext/excel_file_parse.py @@ -0,0 +1,176 @@ +import json +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional + +import pandas as pd +from openpyxl import Workbook, load_workbook + +from dbgpt.util.benchmarks.ExcelUtils import ExcelUtils + +from ..file_parse_service import FileParseService +from ..models import ( + AnswerExecuteModel, + BaseInputModel, + BenchmarkDataSets, + DataCompareStrategyConfig, +) + +logger = logging.getLogger(__name__) + + +class ExcelFileParseService(FileParseService): + def parse_input_sets(self, path: str) -> BenchmarkDataSets: + """ + Parse input sets from excel file + Args: + path: File location path + Returns: + BenchmarkDataSets: Parsed data sets + """ + input_stream = self.get_input_stream(path) + + if input_stream is None: + raise RuntimeError(f"file not found! 
path: {path}") + + # Parse excel file to get data sets + input_sets = BenchmarkDataSets() + workbook = None + + try: + workbook = load_workbook(input_stream, data_only=True) + input_list = [] + + # Get the first worksheet + sheet = workbook.worksheets[0] + + for row_num in range( + 2, sheet.max_row + 1 + ): # Skip header row (start from 1-based index) + row = sheet[row_num] + if ExcelUtils.is_row_empty(row): + continue + + # Get content from columns 1-6 (0-based index 0-5) + serial_no_cell = row[0] + analysis_model_id_cell = row[1] + question_cell = row[2] + self_define_tags_cell = row[3] + knowledge_cell = row[4] + llm_output_cell = row[5] + prompt_cell = row[8] + + # Build input model + input_model = BaseInputModel( + serial_no=int( + ExcelUtils.get_cell_value_as_string(serial_no_cell) or "0" + ), + analysis_model_id=ExcelUtils.get_cell_value_as_string( + analysis_model_id_cell + ), + question=ExcelUtils.get_cell_value_as_string(question_cell), + self_define_tags=ExcelUtils.get_cell_value_as_string( + self_define_tags_cell + ), + llm_output=ExcelUtils.get_cell_value_as_string(llm_output_cell), + knowledge=ExcelUtils.get_cell_value_as_string(knowledge_cell), + prompt=ExcelUtils.get_cell_value_as_string(prompt_cell), + ) + + input_list.append(input_model) + + input_sets.data_list = input_list + except Exception as e: + logger.error(f"parse excel error, path: {path}, errorMsg: {e}") + finally: + try: + if workbook is not None: + workbook.close() + except Exception as e: + logger.error(f"close workbook error, path: {path}, errorMsg: {e}") + + return input_sets + + def parse_standard_benchmark_sets( + self, standard_excel_path: str + ) -> List[AnswerExecuteModel]: + df = pd.read_excel(standard_excel_path, sheet_name=0) + outputs: List[AnswerExecuteModel] = [] + for _, row in df.iterrows(): + try: + serial_no = int(row["编号"]) + except Exception: + continue + question = row.get("用户问题") + analysis_model_id = row.get("数据集ID") + llm_output = ( + None if pd.isna(row.get("标准答案SQL")) else str(row.get("标准答案SQL")) + ) + order_by = True + if not pd.isna(row.get("是否排序")): + try: + order_by = bool(int(row.get("是否排序"))) + except Exception: + order_by = True + + std_result: Optional[List[Dict[str, List[str]]]] = None + if not pd.isna(row.get("标准结果")): + std_result_raw = str(row.get("标准结果")) + std_result = self._parse_multi_standard_result(std_result_raw) + + strategy_config = DataCompareStrategyConfig( + strategy="CONTAIN_MATCH", + order_by=order_by, + standard_result=std_result if std_result is not None else None, + ) + outputs.append( + AnswerExecuteModel( + serialNo=serial_no, + analysisModelId=analysis_model_id, + question=question, + llmOutput=llm_output, + executeResult=std_result, + strategyConfig=strategy_config, + ) + ) + return outputs + + def _parse_multi_standard_result( + self, std_result_raw: str + ) -> Optional[List[Dict[str, List[str]]]]: + """ + Parse multiple standard results from raw string data. + + Handles multiple results separated by newlines and parses each line as a dict. 
+ + Args: + std_result_raw (str): Raw standard result string with multiple lines + + Returns: + Optional[List[Dict[str, List[str]]]]: List of parsed dictionaries, + or None if parsing fails or no valid data + """ + try: + std_result_raw = std_result_raw.strip() + if not std_result_raw: + return None + + # 处理多个结果,通过换行符分隔 + result_lines = std_result_raw.split("\n") + result_list = [] + + for line in result_lines: + line = line.strip() + if line: + try: + result_list.append(json.loads(line)) + except Exception as e: + logger.warning( + f"Failed to parse line as JSON: {line}, error: {e}" + ) + continue + + return result_list if result_list else None + except Exception as e: + logger.error(f"parse multiple standard results error: {e}") + return None diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/ext/falcon_file_parse.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/ext/falcon_file_parse.py new file mode 100644 index 000000000..b1b26a10b --- /dev/null +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/ext/falcon_file_parse.py @@ -0,0 +1,572 @@ +import asyncio +import logging +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +from dbgpt.util import get_or_create_event_loop +from dbgpt_serve.evaluate.service.benchmark.file_parse_service import FileParseService +from dbgpt_serve.evaluate.service.benchmark.models import ( + BaseInputModel, + BenchmarkDataSets, AnswerExecuteModel, DataCompareStrategyConfig, +) +from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import ( + FileLoadResult, + GoldenSqlListResult, + get_benchmark_manager, +) + +logger = logging.getLogger(__name__) + +TEXT_SQL_PROMPT = """ +Given the following dataset, including field names and sampling information: +{Schema} + +Containing the following knowledge: +{Knowledge} + +Based on the above information, please generate the SQL for the following question: +{Query} + +[Output Requirements] +* Field names in the generated SQL must use the actual field names from the table schema. +* Table names in the generated SQL must use the actual table names provided in the schema. +* Physical field names in the generated SQL must originate from the corresponding physical tables; generating fields that do not exist in the table is not allowed. +* Cartesian product calculations are not allowed in the generated SQL. This includes `CROSS JOIN`, `JOIN` operations missing `ON` or `USING` conditions, and multi-table joins without conditions set for all relationships between tables. +* The generated SQL must strictly adhere to {dialect} syntax. If it does not comply with this syntax, please regenerate it. +* Output only pure, executable SQL without any additional information. 
+ +[Example] +** Table 1 Information +*** Table Name: orders +*** DDL Statement: + CREATE TABLE "orders" ( + "order_id" TEXT, + "customer_id" TEXT, + "item_type" TEXT, + "order_date" DATE + ) +*** Field Information: +|Field Name|Field Type|Sample Data| +|:--:|:--:|:--:| +|order_id|text|CN-2025-2562,CN-2025-8623,CN-2025-6535| +|customer_id|text|52351,56263,71252| +|item_type|text|办公用品,设备,家具| +|order_date|date|2023-02-21,2024-03-30,2024-12-20| + +User Question: 请帮我统计下各商品类型的订单数 +Final Output SQL: +SELECT + item_type, + COUNT(*) as order_count +FROM + orders +GROUP BY + item_type; +""" + + +@dataclass +class BenchmarkDataItem: + """benchmark data info item""" + + question_id: int + db_id: str + question: str + sql: str + answer: List[Dict[str, List[str]]] + is_order: str + + @staticmethod + def from_dict(data: dict) -> "BenchmarkDataItem": + return BenchmarkDataItem( + question_id=data.get("question_id", 0), + db_id=str(data.get("db_id", "")), + question=str(data.get("question", "")), + sql=str(data.get("SQL", "")), + answer=data.get("answer", []), + is_order=data.get("is_order", "0"), + ) + +@dataclass +class ColumnItem: + """column info Item""" + + column_id: int + column_name: str + column_type: str + sample_values: list + + @staticmethod + def from_dict(data: dict) -> "ColumnItem": + """从字典创建 ColumnItem 实例 + + Args: + data: 包含列信息的字典 + + Returns: + ColumnItem: 列信息实例 + """ + return ColumnItem( + column_id=data.get("column_id", 0), + column_name=data.get("column_name", ""), + column_type=data.get("column_type", ""), + sample_values=data.get("sample_values", []), + ) + + +@dataclass +class TableDDLItem: + """Table DDL Info Item""" + + table_id: int + table_name: str + columns: List[ColumnItem] + ddl: Optional[str] = None + + @staticmethod + def from_dict(data: dict) -> "TableDDLItem": + """从字典创建 TableDDLItem 实例 + + Args: + data: 包含表信息的字典 + + Returns: + TableDDLItem: 表信息实例 + """ + columns_data = data.get("columns", []) + columns = [ColumnItem.from_dict(col) for col in columns_data] + + return TableDDLItem( + table_id=data.get("table_id", 0), + table_name=data.get("table_name", ""), + columns=columns, + ddl=data.get("ddl"), + ) + + +@dataclass +class TableDataItem: + """Table Data Info Item""" + + db_id: str + table_ddl: List[TableDDLItem] + + @staticmethod + def from_dict(data: dict) -> "TableDataItem": + """从字典创建 TableDataItem 实例 + + Args: + data: 包含数据库表信息的字典 + + Returns: + TableDataItem: 数据库表信息实例 + """ + tables_data = data.get("tables", []) + table_ddl = [TableDDLItem.from_dict(table) for table in tables_data] + + return TableDataItem( + db_id=str(data.get("db_id", "")), + table_ddl=table_ddl, + ) + +class SafeDict(dict): + def __missing__(self, key): + return '{' + key + '}' + + +class FalconFileParseService(FileParseService): + def __init__(self): + super().__init__() + self._dev_data_file = "dev_data/dev.json" + self._dev_table_ddl_file = "dev_data/tables.json" + + self.benchmark_manager = get_benchmark_manager() + + self._dev_data: Optional[FileLoadResult] = None + self._dev_table_ddl: Optional[FileLoadResult] = None + self._data_loaded = False + + @staticmethod + def _format_answer_list(answer_list: List[Dict[str, List[str]]]) -> str: + """格式化 answer 列表为字符串 + + Args: + answer_list: 答案列表,每个元素是字典,字典的值是字符串列表 + + Returns: + str: JSON 格式的字符串,如果列表为空则返回空字符串 + """ + if not answer_list: + return "" + + try: + import json + # 将答案列表转换为 JSON 字符串,每个答案一行 + return "\n".join(json.dumps(item, ensure_ascii=False) for item in answer_list) + except Exception as e: + logger.warning(f"Failed to format answer 
list: {e}") + return str(answer_list) + + def _ensure_data_loaded(self): + if self._data_loaded: + return + logger.info("Loading benchmark data for the first time...") + try: + self._dev_data, self._dev_table_ddl = self._load_data_sync() + self._data_loaded = True + logger.info("Benchmark data loaded successfully") + except Exception as e: + logger.error(f"Failed to load benchmark data: {e}", exc_info=True) + raise RuntimeError(f"Failed to load benchmark data: {e}") + + + def parse_input_sets(self, path: str) -> BenchmarkDataSets: + """ + Parse input sets from github repo + Args: + path: File URL path + Returns: + BenchmarkDataSets: Parsed data sets + """ + self._ensure_data_loaded() + + try: + # 1. 解析评测数据 + benchmark_data_list = self._parse_benchmark_data(self._dev_data) + if not benchmark_data_list: + logger.error("Failed to parse benchmark data") + return BenchmarkDataSets(data_list=[]) + + # 2. 解析表结构 + table_ddl_data_list = self._parse_table_ddl_data(self._dev_table_ddl) + if not table_ddl_data_list: + logger.error("Failed to parse talbe ddl data") + return BenchmarkDataSets(data_list=[]) + table_ddl_data_map = {x.db_id: x.table_ddl for x in table_ddl_data_list} + + # 3. 将问题数据转换为 BaseInputModel 列表,并关联标准答案 + input_models = [] + for idx, question_item in enumerate(benchmark_data_list, start=1): + input_model = BaseInputModel( + serial_no=question_item.question_id, + analysis_model_id=question_item.db_id, + question=question_item.question, + self_define_tags="", + knowledge="", + llm_output=self._format_answer_list(question_item.answer), + prompt=self.load_benchmark_prompt_template(question_item, table_ddl_data_map.get(question_item.db_id)), + ) + input_models.append(input_model) + logger.info(f"Successfully parsed {len(input_models)} question items") + return BenchmarkDataSets(data_list=input_models) + except Exception as e: + logger.error( + f"load remote benchmark data error, error: {str(e)}", + exc_info=True, + ) + return BenchmarkDataSets(data_list=[]) + + def parse_standard_benchmark_sets( + self, standard_excel_path: str + ) -> List[AnswerExecuteModel]: + + self._ensure_data_loaded() + + outputs: List[AnswerExecuteModel] = [] + # 1. 
解析评测数据 + benchmark_data_list = self._parse_benchmark_data(self._dev_data) + if not benchmark_data_list: + logger.error("Failed to parse benchmark data") + return outputs + + for idx, question_item in enumerate(benchmark_data_list, start=1): + serial_no = question_item.question_id + question = question_item.question + analysis_model_id = question_item.db_id + llm_output = question_item.sql + order_by = True + if question_item.is_order: + try: + order_by = bool(int(question_item.is_order)) + except Exception: + order_by = True + + std_result: Optional[List[Dict[str, List[str]]]] = None + if question_item.answer: + std_result = self._parse_multi_standard_result(question_item.answer) + + strategy_config = DataCompareStrategyConfig( + strategy="CONTAIN_MATCH", + order_by=order_by, + standard_result=std_result if std_result is not None else None, + ) + outputs.append( + AnswerExecuteModel( + serialNo=serial_no, + analysisModelId=analysis_model_id, + question=question, + llmOutput=llm_output, + executeResult=std_result, + strategyConfig=strategy_config, + ) + ) + return outputs + + def _parse_benchmark_data( + self, benchmark_data: Optional[FileLoadResult] + ) -> Optional[List[BenchmarkDataItem]]: + """ + 解析问题数据 + Args: + benchmark_data: 从 GitHub 加载的问题文件数据 + Returns: + List[BenchmarkDataItem]: 问题数据列表,如果解析失败返回 None + """ + if not benchmark_data or not benchmark_data.rows: + return None + + if benchmark_data.failed_count > 0: + logger.warning( + f"Question data has {benchmark_data.failed_count} failed rows" + ) + + benchmark_data_list = [] + for row in benchmark_data.rows: + if not isinstance(row.data, dict): + logger.warning( + f"Row {row.line_no} data is not a dict: {type(row.data)}" + ) + continue + + benchmark_data_item = BenchmarkDataItem.from_dict(row.data) + + if ( + not benchmark_data_item.question_id + or not benchmark_data_item.question + or not benchmark_data_item.db_id + ): + logger.warning( + f"Row {row.line_no} missing required fields: " + f"question_id={benchmark_data_item.question_id}, " + f"question={benchmark_data_item.question}, " + f"db_id={benchmark_data_item.db_id}" + ) + continue + + benchmark_data_list.append(benchmark_data_item) + + if not benchmark_data_list: + logger.error("No valid benchmark data parsed") + return None + + logger.info( + f"Successfully parsed {len(benchmark_data_list)} benchmark data" + f" from {len(benchmark_data.rows)} rows" + ) + # TODO 临时只取前 100 条数据 + return benchmark_data_list[:100] + + + def _parse_table_ddl_data( + self, table_ddl_data: Optional[FileLoadResult] + ) -> Optional[List[TableDataItem]]: + """ + 解析表 DDL 数据 + Args: + table_ddl_data: 从 GitHub 加载的表 DDL 数据 + Returns: + List[TableDataItem]: 表 DDL 数据列表,如果解析失败返回 None + """ + if not table_ddl_data or not table_ddl_data.rows: + logger.warning("table ddl data is None") + return None + if table_ddl_data.failed_count > 0: + logger.warning( + f"table ddl data has {table_ddl_data.failed_count} failed items" + ) + + table_ddl_data_list = [] + for row in table_ddl_data.rows: + if not isinstance(row.data, dict): + logger.warning( + f"Row {row.line_no} data is not a dict: {type(row.data)}" + ) + continue + + table_ddl_data_item = TableDataItem.from_dict(row.data) + + if ( + not table_ddl_data_item.db_id + or not table_ddl_data_item.table_ddl + ): + logger.warning( + f"Row {row.line_no} missing required fields: " + f"db_id={table_ddl_data_item.db_id}, " + f"table_ddl={table_ddl_data_item.table_ddl}" + ) + continue + + table_ddl_data_list.append(table_ddl_data_item) + + if not table_ddl_data_list: + return 
None + + logger.info( + f"Successfully parsed {len(table_ddl_data_list)} table DDL data" + f" from {len(table_ddl_data.rows)} rows" + ) + return table_ddl_data_list + + + async def _async_load_data( + self, + ) -> Tuple[ + Optional[FileLoadResult], + Optional[FileLoadResult], + ]: + """并发加载两个文件数据 + + 使用 asyncio.gather 并发执行两个异步任务,提高加载效率 + + Returns: + Tuple: (dev_data, dev_table_ddl) + """ + dev_data, dev_table_ddl = await asyncio.gather( + self.benchmark_manager.load_file_from_github(self._dev_data_file), + self.benchmark_manager.load_file_from_github(self._dev_table_ddl_file), + ) + return dev_data, dev_table_ddl + + def _load_data_sync( + self, + ) -> Tuple[ + Optional[FileLoadResult], + Optional[FileLoadResult], + ]: + """在同步上下文中加载数据 + + 智能检测当前事件循环状态: + - 如果事件循环正在运行,使用线程池在新线程中执行异步代码 + - 如果没有运行中的事件循环,直接使用 run_until_complete + + Returns: + Tuple: (dev_data, dev_table_ddl) + """ + try: + # 尝试获取当前运行中的事件循环 + asyncio.get_running_loop() + # 如果到这里说明有运行中的循环,使用线程池执行 + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, self._async_load_data()) + return future.result() + except RuntimeError: + # 没有运行中的循环,正常执行 + loop = get_or_create_event_loop() + return loop.run_until_complete(self._async_load_data()) + + def _run_in_new_loop( + self, + ) -> Tuple[ + Optional[FileLoadResult], + Optional[FileLoadResult], + ]: + """ + 在新的事件循环中运行异步任务 + + Returns: + Tuple: (dev_data, dev_table_ddl) + """ + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + return new_loop.run_until_complete(self._async_load_data()) + finally: + new_loop.close() + + @staticmethod + def _build_table_schema_info(table_ddl_list: Optional[List[TableDDLItem]]) -> str: + """构建表的 Schema 信息 + + Args: + table_ddl_list: 表 DDL 信息列表 + + Returns: + str: 格式化的表 Schema 信息字符串 + """ + if not table_ddl_list: + return "" + + schema_parts = [] + + for idx, table in enumerate(table_ddl_list, start=1): + # 表头信息 + schema_parts.append(f"** Table {idx} Information") + schema_parts.append(f"*** Table Name: {table.table_name}") + + # 如果有 DDL,添加 DDL 信息 + if table.ddl: + schema_parts.append(f"*** DDL Statement:") + # DDL 可能是多行的,需要缩进处理 + ddl_lines = table.ddl.strip().split('\n') + for ddl_line in ddl_lines: + schema_parts.append(f" {ddl_line}") + + # 列信息表头 - 使用更清晰的格式 + schema_parts.append("*** Field Information:") + schema_parts.append("|Field Name|Field Type|Sample Data|") + schema_parts.append("|:--:|:--:|:--:|") + + # 添加每一列的信息 + for column in table.columns: + # 格式化样本值 - 限制在合理长度内 + if column.sample_values: + # 取前3-5个样本值,用逗号连接,避免过长 + sample_count = min(3, len(column.sample_values)) + sample_str = ",".join(str(val) for val in column.sample_values[:sample_count]) + # 如果样本值过长,截断 + if len(sample_str) > 100: + sample_str = sample_str[:97] + "..." 
+ else: + sample_str = "-" + + schema_parts.append(f"|{column.column_name}|{column.column_type}|{sample_str}|") + + # 表之间添加空行分隔 + if idx < len(table_ddl_list): + schema_parts.append("") + + return "\n".join(schema_parts) + + def _parse_multi_standard_result( + self, answer_list: List[Dict[str, List[str]]] + ) -> Optional[List[Dict[str, List[str]]]]: + """ + 解析标准答案结果 + + Args: + answer_list: 答案列表,已经是正确的格式 + + Returns: + Optional[List[Dict[str, List[str]]]]: 返回答案列表,如果为空返回 None + """ + try: + if not answer_list: + return None + return answer_list if answer_list else None + except Exception as e: + logger.error(f"parse standard results error: {e}") + return None + + def load_benchmark_prompt_template(self, question_item: BenchmarkDataItem, table_ddl: List[TableDDLItem]) -> str: + """ + build benchmark prompt template + """ + schema = self._build_table_schema_info(table_ddl) + format_params = { + "Schema": schema, + "Knowledge": "", + "Query": question_item.question + } + # 使用 SafeDict 和 format_map 实现非严格模式,缺失的变量不会报错 + return TEXT_SQL_PROMPT.format_map(SafeDict(format_params)) \ No newline at end of file diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py index 27bc8f881..0498b7211 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py @@ -4,20 +4,16 @@ import logging import os from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import List, Optional, Any, Dict import pandas as pd from openpyxl import Workbook, load_workbook -from dbgpt.util.benchmarks.ExcelUtils import ExcelUtils -from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao - from .models import ( AnswerExecuteModel, BaseInputModel, BenchmarkDataSets, BenchmarkExecuteConfig, - DataCompareStrategyConfig, OutputType, RoundAnswerConfirmModel, ) @@ -27,8 +23,6 @@ logger = logging.getLogger(__name__) class FileParseService(ABC): def __init__(self): - self._benchmark_dao = BenchmarkResultDao() - # export column configuration file path self._column_config_file_path = os.path.join( os.path.dirname(__file__), @@ -56,7 +50,6 @@ class FileParseService(ABC): data.append(AnswerExecuteModel.from_dict(obj)) return data - @abstractmethod def write_data_compare_result( self, path: str, @@ -65,15 +58,100 @@ class FileParseService(ABC): is_execute: bool, llm_count: int, ): - """Write compare results to File + """Write compare results to an Excel file - Args: - path: Output file path - round_id: Round ID - confirm_models: List of answer confirm models - is_execute: Whether to execute the comparison - llm_count: LLM count + The output Excel file will be named as '.xlsx' and + sheet name is 'benchmark_compare_result'. If the file exists, it will + append rows; otherwise it will create a new file with headers. 
""" + try: + # Ensure output directory exists + output_dir = Path(path).parent + output_dir.mkdir(parents=True, exist_ok=True) + + output_file = path + + headers = [ + "serialNo", + "analysisModelId", + "question", + "selfDefineTags", + "prompt", + "standardAnswerSql", + "standardAnswer", + "llmCode", + "llmOutput", + "executeResult", + "errorMsg", + "compareResult", + ] + + # Load or create workbook and sheet + if Path(output_file).exists(): + workbook = load_workbook(str(output_file)) + if "benchmark_compare_result" in workbook.sheetnames: + worksheet = workbook["benchmark_compare_result"] + else: + worksheet = workbook.create_sheet("benchmark_compare_result") + # Write headers if new sheet + for col_idx, header in enumerate(headers, 1): + worksheet.cell(row=1, column=col_idx, value=header) + else: + workbook = Workbook() + worksheet = workbook.active + worksheet.title = "benchmark_compare_result" + # Write headers + for col_idx, header in enumerate(headers, 1): + worksheet.cell(row=1, column=col_idx, value=header) + + # Determine start row to append + start_row = worksheet.max_row + 1 if worksheet.max_row else 2 + + # Append rows + for idx, cm in enumerate(confirm_models): + row_data = [ + cm.serialNo, + cm.analysisModelId, + cm.question, + cm.selfDefineTags, + cm.prompt, + cm.standardAnswerSql, + self._format_set_result(cm.strategyConfig.standard_result) + if cm.strategyConfig is not None + else "", + cm.llmCode, + cm.llmOutput, + json.dumps(cm.executeResult, ensure_ascii=False) + if cm.executeResult is not None + else "", + cm.errorMsg, + cm.compareResult.value if cm.compareResult else None, + ] + for col_idx, value in enumerate(row_data, 1): + worksheet.cell(row=start_row + idx, column=col_idx, value=value) + + # Autosize columns (simple strategy) + for column in worksheet.columns: + max_length = 0 + column_letter = column[0].column_letter + for cell in column: + try: + if cell.value and len(str(cell.value)) > max_length: + max_length = len(str(cell.value)) + except Exception: + pass + adjusted_width = min(max(max_length + 2, 10), 80) + worksheet.column_dimensions[column_letter].width = adjusted_width + + workbook.save(str(output_file)) + workbook.close() + logger.info( + f"[write_data_compare_result] compare written to Excel: {output_file}" + ) + except Exception as e: + logger.error( + f"[write_data_compare_result] write excel error for path={path}: {e}" + ) def summary_and_write_multi_round_benchmark_result( self, output_path: str, round_id: int @@ -163,7 +241,6 @@ class FileParseService(ABC): """ pass - @abstractmethod def write_multi_round_benchmark_result( self, output_file_path: str, @@ -186,149 +263,6 @@ class FileParseService(ABC): start_index: Starting index (batch start row index) offset: Offset(file rows offset) """ - - -class ExcelFileParseService(FileParseService): - def parse_input_sets(self, path: str) -> BenchmarkDataSets: - """ - Parse input sets from excel file - Args: - path: File location path - Returns: - BenchmarkDataSets: Parsed data sets - """ - input_stream = self.get_input_stream(path) - - if input_stream is None: - raise RuntimeError(f"file not found! 
path: {path}") - - # Parse excel file to get data sets - input_sets = BenchmarkDataSets() - workbook = None - - try: - workbook = load_workbook(input_stream, data_only=True) - input_list = [] - - # Get the first worksheet - sheet = workbook.worksheets[0] - - for row_num in range( - 2, sheet.max_row + 1 - ): # Skip header row (start from 1-based index) - row = sheet[row_num] - if ExcelUtils.is_row_empty(row): - continue - - # Get content from columns 1-6 (0-based index 0-5) - serial_no_cell = row[0] - analysis_model_id_cell = row[1] - question_cell = row[2] - self_define_tags_cell = row[3] - knowledge_cell = row[4] - llm_output_cell = row[5] - prompt_cell = row[8] - - # Build input model - input_model = BaseInputModel( - serial_no=int( - ExcelUtils.get_cell_value_as_string(serial_no_cell) or "0" - ), - analysis_model_id=ExcelUtils.get_cell_value_as_string( - analysis_model_id_cell - ), - question=ExcelUtils.get_cell_value_as_string(question_cell), - self_define_tags=ExcelUtils.get_cell_value_as_string( - self_define_tags_cell - ), - llm_output=ExcelUtils.get_cell_value_as_string(llm_output_cell), - knowledge=ExcelUtils.get_cell_value_as_string(knowledge_cell), - prompt=ExcelUtils.get_cell_value_as_string(prompt_cell), - ) - - input_list.append(input_model) - - input_sets.data_list = input_list - except Exception as e: - logger.error(f"parse excel error, path: {path}, errorMsg: {e}") - finally: - try: - if workbook is not None: - workbook.close() - except Exception as e: - logger.error(f"close workbook error, path: {path}, errorMsg: {e}") - - return input_sets - - def parse_standard_benchmark_sets( - self, standard_excel_path: str - ) -> List[AnswerExecuteModel]: - df = pd.read_excel(standard_excel_path, sheet_name=0) - outputs: List[AnswerExecuteModel] = [] - for _, row in df.iterrows(): - try: - serial_no = int(row["编号"]) - except Exception: - continue - question = row.get("用户问题") - analysis_model_id = row.get("数据集ID") - llm_output = ( - None if pd.isna(row.get("标准答案SQL")) else str(row.get("标准答案SQL")) - ) - order_by = True - if not pd.isna(row.get("是否排序")): - try: - order_by = bool(int(row.get("是否排序"))) - except Exception: - order_by = True - - std_result: Optional[List[Dict[str, List[str]]]] = None - if not pd.isna(row.get("标准结果")): - std_result_raw = str(row.get("标准结果")) - std_result = self._parse_multi_standard_result(std_result_raw) - - strategy_config = DataCompareStrategyConfig( - strategy="CONTAIN_MATCH", - order_by=order_by, - standard_result=std_result if std_result is not None else None, - ) - outputs.append( - AnswerExecuteModel( - serialNo=serial_no, - analysisModelId=analysis_model_id, - question=question, - llmOutput=llm_output, - executeResult=std_result, - strategyConfig=strategy_config, - ) - ) - return outputs - - def write_multi_round_benchmark_result( - self, - output_file_path: str, - round_id: int, - config: BenchmarkExecuteConfig, - inputs: List[BaseInputModel], - outputs: List[OutputType], - start_index: int, - offset: int, - ) -> bool: - """ - Write the benchmark Result to Excel File With Multi Round - - Args: - output_file_path: Output file path - round_id: Round ID - config: Benchmark configuration - inputs: List of input data - outputs: List of output data - start_index: Starting index (batch start row index) - offset: Offset(file rows offset) - - Returns: - bool: Returns True if write is successful, False otherwise - """ try: # 确保输出目录存在 output_dir = Path(output_file_path).parent @@ -436,109 +370,6 @@ class ExcelFileParseService(FileParseService): 
logger.error(f"write excel file error: {e}", exc_info=True) return False - def write_data_compare_result( - self, - path: str, - round_id: int, - confirm_models: List[RoundAnswerConfirmModel], - is_execute: bool, - llm_count: int, - ): - """Write compare results to an Excel file - - The output Excel file will be named as '.xlsx' and - sheet name is 'benchmark_compare_result'. If the file exists, it will - append rows; otherwise it will create a new file with headers. - """ - try: - # Ensure output directory exists - output_dir = Path(path).parent - output_dir.mkdir(parents=True, exist_ok=True) - - output_file = path - - headers = [ - "serialNo", - "analysisModelId", - "question", - "selfDefineTags", - "prompt", - "standardAnswerSql", - "standardAnswer", - "llmCode", - "llmOutput", - "executeResult", - "errorMsg", - "compareResult", - ] - - # Load or create workbook and sheet - if Path(output_file).exists(): - workbook = load_workbook(str(output_file)) - if "benchmark_compare_result" in workbook.sheetnames: - worksheet = workbook["benchmark_compare_result"] - else: - worksheet = workbook.create_sheet("benchmark_compare_result") - # Write headers if new sheet - for col_idx, header in enumerate(headers, 1): - worksheet.cell(row=1, column=col_idx, value=header) - else: - workbook = Workbook() - worksheet = workbook.active - worksheet.title = "benchmark_compare_result" - # Write headers - for col_idx, header in enumerate(headers, 1): - worksheet.cell(row=1, column=col_idx, value=header) - - # Determine start row to append - start_row = worksheet.max_row + 1 if worksheet.max_row else 2 - - # Append rows - for idx, cm in enumerate(confirm_models): - row_data = [ - cm.serialNo, - cm.analysisModelId, - cm.question, - cm.selfDefineTags, - cm.prompt, - cm.standardAnswerSql, - self._format_set_result(cm.strategyConfig.standard_result) - if cm.strategyConfig is not None - else "", - cm.llmCode, - cm.llmOutput, - json.dumps(cm.executeResult, ensure_ascii=False) - if cm.executeResult is not None - else "", - cm.errorMsg, - cm.compareResult.value if cm.compareResult else None, - ] - for col_idx, value in enumerate(row_data, 1): - worksheet.cell(row=start_row + idx, column=col_idx, value=value) - - # Autosize columns (simple strategy) - for column in worksheet.columns: - max_length = 0 - column_letter = column[0].column_letter - for cell in column: - try: - if cell.value and len(str(cell.value)) > max_length: - max_length = len(str(cell.value)) - except Exception: - pass - adjusted_width = min(max(max_length + 2, 10), 80) - worksheet.column_dimensions[column_letter].width = adjusted_width - - workbook.save(str(output_file)) - workbook.close() - logger.info( - f"[write_data_compare_result] compare written to Excel: {output_file}" - ) - except Exception as e: - logger.error( - f"[write_data_compare_result] write excel error for path={path}: {e}" - ) - def _get_value_by_source_type( self, field: str, @@ -636,46 +467,6 @@ class ExcelFileParseService(FileParseService): else: return str(value) if value is not None else "" - def _parse_multi_standard_result( - self, std_result_raw: str - ) -> Optional[List[Dict[str, List[str]]]]: - """ - Parse multiple standard results from raw string data. - - Handles multiple results separated by newlines and parses each line as a dict. 
-
-        Args:
-            std_result_raw (str): Raw standard result string with multiple lines
-
-        Returns:
-            Optional[List[Dict[str, List[str]]]]: List of parsed dictionaries,
-            or None if parsing fails or no valid data
-        """
-        try:
-            std_result_raw = std_result_raw.strip()
-            if not std_result_raw:
-                return None
-
-            # 处理多个结果,通过换行符分隔
-            result_lines = std_result_raw.split("\n")
-            result_list = []
-
-            for line in result_lines:
-                line = line.strip()
-                if line:
-                    try:
-                        result_list.append(json.loads(line))
-                    except Exception as e:
-                        logger.warning(
-                            f"Failed to parse line as JSON: {line}, error: {e}"
-                        )
-                        continue
-
-            return result_list if result_list else None
-        except Exception as e:
-            logger.error(f"parse multiple standard results error: {e}")
-            return None
-
     def _format_set_result(
         self, sql_result: List[Dict[str, List[str]]]
     ) -> Optional[str]:
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py
index 763729d32..7cd47503c 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py
@@ -4,7 +4,6 @@ import logging
 import os
 from typing import Dict, List, Optional, Union
 
-from dbgpt.util.benchmarks import StorageUtil
 from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao
 from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
     BENCHMARK_DEFAULT_DB_SCHEMA,
@@ -12,6 +11,8 @@ from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
 )
 
 from .data_compare_service import DataCompareService
+from .ext.excel_file_parse import ExcelFileParseService
+from .ext.falcon_file_parse import FalconFileParseService
 from .file_parse_service import FileParseService
 from .models import (
     AnswerExecuteModel,
@@ -33,14 +34,28 @@ logger = logging.getLogger(__name__)
 
 class UserInputExecuteService:
     def __init__(
-        self, file_service: FileParseService, compare_service: DataCompareService
+        self, compare_service: DataCompareService, file_type: Optional[FileParseTypeEnum] = None
     ):
-        self.file_service = file_service
         self.compare_service = compare_service
+        self.file_service = self._create_file_service(file_type)
+        self.dao = BenchmarkResultDao()
 
         # sql query timeout in seconds
         self.query_timeout = float(os.getenv("BENCHMARK_SQL_TIMEOUT", 360.0))
 
+    def _create_file_service(self, file_type: FileParseTypeEnum) -> FileParseService:
+        """
+        Get file service instance based on file type. 
+ + Returns: + FileParseService: File service instance + """ + if file_type == FileParseTypeEnum.GITHUB: + return FalconFileParseService() + elif file_type == FileParseTypeEnum.EXCEL: + return ExcelFileParseService() + raise NotImplementedError(f"filePraseType: {file_type} is not implemented yet") + def read_input_file( self, input_file_path: str ) -> Union[List[BaseInputModel], None]: @@ -53,15 +68,10 @@ class UserInputExecuteService: Returns: List[BaseInputModel]: Input data list """ - file_parse_type: FileParseTypeEnum = StorageUtil.get_file_parse_type( + input_sets: BenchmarkDataSets = self.file_service.parse_input_sets( input_file_path ) - if file_parse_type == FileParseTypeEnum.EXCEL: - input_sets: BenchmarkDataSets = self.file_service.parse_input_sets( - input_file_path - ) - return input_sets.data_list - return None + return input_sets.data_list def post_dispatch( self, @@ -225,14 +235,13 @@ class UserInputExecuteService: ) results = json.loads(summary_json) if summary_json else [] - dao = BenchmarkResultDao() for item in results: llm_code = item.get("llmCode") right = int(item.get("right", 0)) wrong = int(item.get("wrong", 0)) failed = int(item.get("failed", 0)) exception = int(item.get("exception", 0)) - dao.upsert_summary( + self.dao.upsert_summary( round_id, location, llm_code, diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py index 205383b85..f3fcf387a 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py @@ -4,15 +4,17 @@ import hashlib import json import logging import os +import re import shutil import tempfile import threading import time +import uuid import zipfile from concurrent.futures import ThreadPoolExecutor from concurrent.futures import TimeoutError as FutureTimeoutError from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, cast +from typing import Any, Dict, List, Optional, Tuple, cast, Literal import aiohttp from sqlalchemy import text @@ -24,6 +26,90 @@ from dbgpt_ext.datasource.rdbms.conn_sqlite import SQLiteConnector logger = logging.getLogger(__name__) +# ---- Unified model result definitions for load_file_from_github ---- +class FailureDetail(BaseModel): + line_no: int + error: str + line: str + +class Row(BaseModel): + line_no: int + data: Any + +class FileLoadResult(BaseModel): + type: Literal["jsonl", "json", "text"] + file_path: str + file_name: str + encoding: Optional[str] = None + rows: List[Row] + count: int + failed_count: int + failures: List[FailureDetail] = [] + + +class SqlFileItem(BaseModel): + """Represents a single SQL file with its ID and content""" + + sql_id: str + sql_content: str + file_path: str + file_name: str + encoding: Optional[str] = None + + +class GoldenSqlListResult(BaseModel): + """Result object for golden SQL list loading + + Provides efficient lookup by SQL ID with dict-like interface. 
+ """ + sql_items: Dict[str, SqlFileItem] + total_count: int + failed_count: int + + def get_by_id(self, sql_id: str) -> Optional[SqlFileItem]: + """Get SQL item by ID + + Args: + sql_id: The SQL file ID (filename prefix without extension) + + Returns: + SqlFileItem if found, None otherwise + """ + return self.sql_items.get(sql_id) + + def get_sql_content(self, sql_id: str) -> Optional[str]: + """Get SQL content by ID + + Args: + sql_id: The SQL file ID (filename prefix without extension) + + Returns: + SQL content string if found, None otherwise + """ + item = self.sql_items.get(sql_id) + return item.sql_content if item else None + + def list_all_ids(self) -> List[str]: + """Get list of all SQL IDs + + Returns: + List of SQL IDs sorted alphabetically + """ + return sorted(self.sql_items.keys()) + + def __len__(self) -> int: + """Return number of successfully loaded SQL files""" + return len(self.sql_items) + + def __contains__(self, sql_id: str) -> bool: + """Check if SQL ID exists""" + return sql_id in self.sql_items + + def __iter__(self): + """Iterate over SQL items""" + return iter(self.sql_items.values()) + + BENCHMARK_DEFAULT_DB_SCHEMA = "ant_icube_dev." @@ -36,12 +122,10 @@ class BenchmarkDataConfig(BaseModel): db_path: str = os.path.join( BENCHMARK_DATA_ROOT_PATH, f"{BENCHMARK_DEFAULT_DB_SCHEMA}db" ) - table_mapping_file: str = os.path.join( - BENCHMARK_DATA_ROOT_PATH, "table_mapping.json" - ) + table_mapping_file: Optional[str] = None cache_expiry_days: int = 1 repo_url: str = "https://github.com/eosphoros-ai/Falcon" - data_dir: str = "data/source" + data_dirs: List[str] = ["dev_data/dev_databases", "test_data/dev_databases"] class BenchmarkDataManager(BaseComponent): @@ -56,7 +140,6 @@ class BenchmarkDataManager(BaseComponent): self._config = config or BenchmarkDataConfig() self._http_session: Optional[aiohttp.ClientSession] = None self._connector: Optional[SQLiteConnector] = None - self._table_mappings = self._load_mappings() self._lock = asyncio.Lock() self.temp_dir: Optional[str] = None @@ -73,9 +156,7 @@ class BenchmarkDataManager(BaseComponent): async def async_before_stop(self): try: - logger.info("BenchmarkDataManager: closing resources before stop...") await self.close() - logger.info("BenchmarkDataManager: close done.") except Exception as e: logger.warning(f"BenchmarkDataManager: close failed: {e}") @@ -120,7 +201,6 @@ class BenchmarkDataManager(BaseComponent): async def load_data(self): logger.info("BenchmarkDataManager: start load_data.") - try: if not self._config.repo_url: logger.info("BenchmarkDataManager: repo_url not set, skip auto load.") @@ -132,69 +212,16 @@ class BenchmarkDataManager(BaseComponent): logger.info( f"BenchmarkDataManager: auto loading repo {self._config.repo_url} " - f"dir={self._config.data_dir}" + f"dirs={self._config.data_dirs}" ) await get_benchmark_manager(self.system_app).load_from_github( - repo_url=self._config.repo_url, data_dir=self._config.data_dir + repo_url=self._config.repo_url, data_dirs=self._config.data_dirs ) self._startup_loaded = True logger.info("BenchmarkDataManager: auto load finished.") except Exception as e: logger.error(f"BenchmarkDataManager: auto load failed: {e}") - def _sanitize_column_name(self, name: str) -> str: - if name is None: - return "" - name = str(name).strip().strip('"').strip("'") - invalid_chars = [ - "-", - " ", - ".", - ",", - ";", - ":", - "!", - "?", - "'", - '"', - "(", - ")", - "[", - "]", - "{", - "}", - "\t", - "\r", - "\n", - "\x00", - ] - while name and name[-1] in invalid_chars: - name = 
name[:-1] - for ch in invalid_chars: - if ch in name: - name = name.replace(ch, "_") - while "__" in name: - name = name.replace("__", "_") - if name and not (name[0].isalpha() or name[0] == "_"): - name = "_" + name - return name.lower() - - def _sanitize_and_dedup_headers(self, headers: List[str]) -> List[str]: - sanitized: List[str] = [] - used: set = set() - for idx, h in enumerate(headers): - name = self._sanitize_column_name(h) - if not name: - name = f"col_{idx}" - base = name - k = 2 - while name in used or not name: - name = f"{base}_{k}" - k += 1 - used.add(name) - sanitized.append(name) - return sanitized - # ========================================================== # 通用查询(阻塞实现,在线程池中调用,支持超时与可中断) @@ -321,7 +348,9 @@ class BenchmarkDataManager(BaseComponent): return [dict(zip(cols, row)) for row in rows] async def load_from_github( - self, repo_url: str, data_dir: str = "data/source" + self, + repo_url: str, + data_dirs: List[str] = ["dev_data/dev_databases", "test_data/dev_databases"], ) -> Dict: """Main method to load data from GitHub repository""" try: @@ -330,14 +359,26 @@ class BenchmarkDataManager(BaseComponent): # 1. Download or use cached repository repo_dir = await self._download_repo_contents(repo_url) - # 2. Find all CSV files recursively - csv_files = self._discover_csv_files(repo_dir, data_dir) - if not csv_files: - raise ValueError("No CSV files found") - logger.info(f"Found {len(csv_files)} CSV files") + # 2. Find all SQLite files recursively in the specified data_dirs + all_sqlite_files = [] + for d_dir in data_dirs: + try: + files = self._discover_sqlite_files(repo_dir, d_dir) + logger.info(f"Found {len(files)} SQLite files in {d_dir}") + all_sqlite_files.extend(files) + except ValueError as ve: + # 如果某个目录不存在,记录警告但不中断整个流程 + logger.warning(f"Skip directory {d_dir}: {ve}") - # 3. Import to SQLite - result = await self._import_to_database(csv_files) + if not all_sqlite_files: + raise ValueError( + f"No SQLite files found in any of the directories: {data_dirs}" + ) + + logger.info(f"Total SQLite files to merge: {len(all_sqlite_files)}") + + # 3. Merge all SQLite files into the main database + result = await self._merge_sqlite_databases(all_sqlite_files) return result except Exception as e: @@ -346,6 +387,280 @@ class BenchmarkDataManager(BaseComponent): finally: self._cleanup_temp_dir() + async def load_file_from_github(self, file_name: Optional[str] = None + ) -> Optional[FileLoadResult]: + """Download and read a specified file from a GitHub repository. + + Supported file types: .json / .jsonl + `file_name` can be a relative path within the repository or a plain filename (will be searched recursively). 
+ + Unified return structure (FileLoadResult): + - type: "json" | "jsonl" + - file_path, file_name, encoding + - rows: List[{line_no:int, data:Any}] where data is parsed JSON object + - count: total number of rows + - failed_count: number of failed lines (non-zero for jsonl or malformed json) + - failures: details for failed lines + + For JSON files: + - If the file contains a JSON array, each element becomes a Row + - If the file contains a single JSON object, it becomes one Row + - The structure is flexible and doesn't depend on specific keys + """ + try: + if not file_name or not str(file_name).strip(): + return None + + # Download or use cached repository + repo_dir = await self._download_repo_contents(self._config.repo_url) + + # Allowed file extensions + allowed_exts = {".jsonl", ".json"} + + # Pre-check extension of `file_name` (if provided), otherwise filter by allowed list later + _, requested_ext = os.path.splitext(str(file_name).lower()) + if requested_ext and requested_ext not in allowed_exts: + raise ValueError(f"Unsupported file type: {requested_ext}") + + # Handle both relative path and plain filename cases + normalized = str(file_name).strip().lstrip("/").replace("\\", os.sep) + candidate_paths: List[str] = [] + + # Prefer direct path resolution using the relative path + direct_path = os.path.join(repo_dir, normalized) + if os.path.isfile(direct_path): + ext = os.path.splitext(direct_path.lower())[1] + if not requested_ext: + if ext in allowed_exts: + candidate_paths.append(direct_path) + elif ext == requested_ext: + candidate_paths.append(direct_path) + + # If not found, recursively search by filename match + if not candidate_paths: + target_name = os.path.basename(normalized) + for root, _, files in os.walk(repo_dir): + for f in files: + if f == target_name: + full = os.path.join(root, f) + ext = os.path.splitext(f.lower())[1] + if not requested_ext: + if ext in allowed_exts: + candidate_paths.append(full) + elif ext == requested_ext: + candidate_paths.append(full) + + if not candidate_paths: + raise FileNotFoundError(f"File not found: {file_name}") + + # Choose a stable candidate (sorted by path length and lexicographical order) + chosen = sorted(candidate_paths, key=lambda p: (len(p), p))[0] + chosen_ext = os.path.splitext(chosen.lower())[1] + + # Build repository-relative path for the file (avoid returning temp local path) + rel_path = os.path.relpath(chosen, repo_dir) + rel_path_posix = rel_path.replace(os.sep, "/") + + # Try multiple encodings + encodings = ["utf-8", "iso-8859-1"] + + # Handle .json files (array or single object) + if chosen_ext == ".json": + return await self._parse_json_file( + chosen, rel_path_posix, encodings + ) + + # Handle .jsonl files (line-delimited JSON) + elif chosen_ext == ".jsonl": + return await self._parse_jsonl_file( + chosen, rel_path_posix, encodings + ) + + else: + raise ValueError(f"Unsupported file extension: {chosen_ext}") + + except Exception as e: + logger.error(f"Falcon repository Import failed: {str(e)}") + raise RuntimeError(f"Falcon repository file data loading failed: {e}") from e + finally: + self._cleanup_temp_dir() + + async def _parse_json_file( + self, file_path: str, rel_path_posix: str, encodings: List[str] + ) -> FileLoadResult: + """Parse a JSON file (array or single object). 
+ + Args: + file_path: Absolute path to the JSON file + rel_path_posix: Repository-relative path in POSIX format + encodings: List of encodings to try + + Returns: + FileLoadResult with parsed data + """ + rows: List[Row] = [] + failures: List[FailureDetail] = [] + used_encoding: Optional[str] = None + + # Try reading with different encodings + for enc in encodings: + try: + with open(file_path, "r", encoding=enc) as f: + content = f.read() + + try: + data = json.loads(content) + + # Handle JSON array + if isinstance(data, list): + for idx, item in enumerate(data, start=1): + rows.append(Row(line_no=idx, data=item)) + # Handle single JSON object + elif isinstance(data, dict): + rows.append(Row(line_no=1, data=data)) + else: + # Handle primitive types (string, number, etc.) + rows.append(Row(line_no=1, data=data)) + + used_encoding = enc + break + + except json.JSONDecodeError as e: + failures.append( + FailureDetail( + line_no=1, + error=f"JSON decode error: {str(e)}", + line=content[:200], + ) + ) + used_encoding = enc + break + + except UnicodeDecodeError: + continue + except Exception as e: + logger.warning(f"Read json with encoding {enc} failed: {e}") + continue + + # Fallback: read as bytes and decode with ASCII ignoring errors + if used_encoding is None: + try: + with open(file_path, "rb") as f: + content = f.read().decode("ascii", errors="ignore") + + try: + data = json.loads(content) + + if isinstance(data, list): + for idx, item in enumerate(data, start=1): + rows.append(Row(line_no=idx, data=item)) + elif isinstance(data, dict): + rows.append(Row(line_no=1, data=data)) + else: + rows.append(Row(line_no=1, data=data)) + + except json.JSONDecodeError as e: + failures.append( + FailureDetail( + line_no=1, + error=f"JSON decode error: {str(e)}", + line=content[:200], + ) + ) + + used_encoding = "ascii-ignore" + except Exception as e: + raise ValueError(f"Failed to read json file: {e}") + + return FileLoadResult( + type="json", + file_path=rel_path_posix, + file_name=os.path.basename(file_path), + encoding=used_encoding, + rows=rows, + count=len(rows) + len(failures), + failed_count=len(failures), + failures=failures, + ) + + async def _parse_jsonl_file( + self, file_path: str, rel_path_posix: str, encodings: List[str] + ) -> FileLoadResult: + """Parse a JSONL file (line-delimited JSON). 
+ + Args: + file_path: Absolute path to the JSONL file + rel_path_posix: Repository-relative path in POSIX format + encodings: List of encodings to try + + Returns: + FileLoadResult with parsed data + """ + rows: List[Row] = [] + failures: List[FailureDetail] = [] + used_encoding: Optional[str] = None + + # Prefer reading in text mode with multiple encodings + for enc in encodings: + try: + with open(file_path, "r", encoding=enc) as f: + for idx, line in enumerate(f, start=1): + s = line.strip() + if not s: + continue + try: + obj = json.loads(s) + rows.append(Row(line_no=idx, data=obj)) + except Exception as e: + failures.append( + FailureDetail( + line_no=idx, + error=str(e), + line=s[:200], + ) + ) + used_encoding = enc + break + except UnicodeDecodeError: + continue + except Exception as e: + logger.warning(f"Read jsonl with encoding {enc} failed: {e}") + continue + + # Fallback: read as bytes and decode with ASCII ignoring errors + if used_encoding is None: + try: + with open(file_path, "rb") as f: + for idx, raw_line in enumerate(f, start=1): + s = raw_line.decode("ascii", errors="ignore").strip() + if not s: + continue + try: + obj = json.loads(s) + rows.append(Row(line_no=idx, data=obj)) + except Exception as e: + failures.append( + FailureDetail( + line_no=idx, + error=str(e), + line=s[:200], + ) + ) + used_encoding = "ascii-ignore" + except Exception as e: + raise ValueError(f"Failed to read jsonl file: {e}") + + return FileLoadResult( + type="jsonl", + file_path=rel_path_posix, + file_name=os.path.basename(file_path), + encoding=used_encoding, + rows=rows, + count=(len(rows) + len(failures)), + failed_count=len(failures), + failures=failures, + ) + async def get_table_info(self) -> Dict: """Get metadata about all tables""" await self.init_connector() @@ -445,7 +760,7 @@ class BenchmarkDataManager(BaseComponent): return (mapped_name or "").lower() async def _download_repo_contents(self, repo_url: str) -> str: - """Download repository with caching""" + """Download repository with caching, supporting branch URLs""" cache_path = self._get_cache_path(repo_url) # Use cache if valid @@ -455,21 +770,45 @@ class BenchmarkDataManager(BaseComponent): # Download fresh copy self.temp_dir = tempfile.mkdtemp() - zip_url = ( - repo_url.replace("github.com", "api.github.com/repos") + "/zipball/main" - ) + + # Simple parsing for github.com URLs + github_pattern = r"github\.com/([^/]+)/([^/]+)(?:/tree/(.+))?" 
+ match = re.search(github_pattern, repo_url) + + if match: + owner, repo, branch = match.groups() + branch = branch or "main" # Default to main if no tree/branch specified + zip_url = f"https://api.github.com/repos/{owner}/{repo}/zipball/{branch}" + else: + # Fallback for generic structure or direct zip links + if repo_url.endswith(".zip"): + zip_url = repo_url + else: + # Default fallback behavior from original code + zip_url = ( + repo_url.replace("github.com", "api.github.com/repos") + + "/zipball/main" + ) + logger.info(f"Downloading from GitHub repo: {zip_url}") try: if self._http_session is None: self._http_session = aiohttp.ClientSession() - async with self._http_session.get(zip_url) as response: - response.raise_for_status() + + headers = {"Accept": "application/vnd.github.v3+json"} + async with self._http_session.get(zip_url, headers=headers) as response: + if response.status != 200: + text_resp = await response.text() + raise RuntimeError( + f"GitHub API Error {response.status}: {text_resp}" + ) + zip_path = os.path.join(self.temp_dir, "repo.zip") with open(zip_path, "wb") as f: while True: - chunk = await response.content.read(1024) + chunk = await response.content.read(1024 * 1024) # 1MB chunks if not chunk: break f.write(chunk) @@ -479,7 +818,6 @@ class BenchmarkDataManager(BaseComponent): logger.info(f"Saved repository to cache: {cache_path}") return self._extract_zip(zip_path) - except Exception as e: self._cleanup_temp_dir() raise RuntimeError(f"Failed to download repository: {str(e)}") from e @@ -515,252 +853,112 @@ class BenchmarkDataManager(BaseComponent): raise ValueError("No valid directory found after extraction") return os.path.join(self.temp_dir, extracted_dirs[0]) - def _discover_csv_files(self, base_dir: str, search_dir: str) -> List[Dict]: - """Find all CSV files recursively""" + def _discover_sqlite_files(self, base_dir: str, search_dir: str) -> List[str]: + """Find all SQLite files recursively in the search directory""" full_search_dir = os.path.join(base_dir, search_dir) if search_dir else base_dir if not os.path.exists(full_search_dir): raise ValueError(f"Directory not found: {full_search_dir}") - csv_files = [] + sqlite_files = [] for root, _, files in os.walk(full_search_dir): for file in files: - if file.lower().endswith(".csv"): - rel_path = os.path.relpath(root, start=base_dir) - csv_files.append( - { - "full_path": os.path.join(root, file), - "rel_path": rel_path, - "file_name": file, - } - ) - return csv_files + if file.lower().endswith(".sqlite"): + full_path = os.path.join(root, file) + sqlite_files.append(full_path) + return sqlite_files - async def _import_to_database(self, csv_files: List[Dict]) -> Dict: - """Import CSV data to SQLite""" + async def _merge_sqlite_databases(self, sqlite_files: List[str]) -> Dict: + """Merge multiple SQLite files into the main database""" await self.init_connector() assert self._connector is not None - results = { - "total_files": len(csv_files), - "successful": 0, - "failed": 0, - "tables_created": [], - } - def _process_one_file(file_info: Dict) -> Tuple[bool, Optional[str]]: - table_name = "" - try: - path_parts = [p for p in file_info["rel_path"].split(os.sep) if p] - table_name = "_".join(path_parts + [Path(file_info["file_name"]).stem]) - table_name = self._sanitize_table_name(table_name) - - with self._connector.session_scope() as session: - session.execute(text(f'DROP TABLE IF EXISTS "{table_name}"')) - session.commit() - encodings = ["utf-8-sig", "utf-8", "latin-1", "iso-8859-1", "cp1252"] - - for encoding 
-                    try:
-                        with open(file_info["full_path"], "r", encoding=encoding) as f:
-                            content = f.read()
-
-                        if not content.strip():
-                            raise ValueError("File is empty")
-
-                        content = content.replace("\r\n", "\n").replace("\r", "\n")
-                        lines = [line for line in content.split("\n") if line.strip()]
-                        if not lines:
-                            raise ValueError("No data after normalization")
-
-                        header_line = lines[0]
-                        data_line = lines[1] if len(lines) > 1 else ""
-
-                        try:
-                            sample_for_sniff = "\n".join(lines[:10])
-                            sniffer = csv.Sniffer()
-                            try:
-                                dialect = sniffer.sniff(sample_for_sniff)
-                            except Exception:
-                                # Fallback: choose delimiter by counting common
-                                # separators in header/data line
-                                delims = [",", "\t", ";", "|"]
-                                counts = {
-                                    d: (header_line.count(d) if header_line else 0)
-                                    + (data_line.count(d) if data_line else 0)
-                                    for d in delims
-                                }
-                                best = (
-                                    max(counts, key=counts.get)
-                                    if any(counts.values())
-                                    else ","
-                                )
-
-                                class _DefaultDialect(csv.Dialect):
-                                    delimiter = best
-                                    quotechar = '"'
-                                    doublequote = True
-                                    skipinitialspace = False
-                                    lineterminator = "\n"
-                                    quoting = csv.QUOTE_MINIMAL
-
-                                dialect = _DefaultDialect()
-
-                            try:
-                                has_header = sniffer.has_header("\n".join(lines[:50]))
-                            except Exception:
-                                has_header = True
-
-                            header_row = (
-                                list(csv.reader([header_line], dialect))[0]
-                                if header_line
-                                else []
-                            )
-                            first_data_row = (
-                                list(csv.reader([data_line], dialect))[0]
-                                if data_line
-                                else []
-                            )
-
-                            # Heuristic: if has_header is False but header_row looks
-                            # like names (mostly alphabetic), treat as header
-                            if not has_header:
-
-                                def _looks_like_header(tokens: List[str]) -> bool:
-                                    if not tokens:
-                                        return False
-                                    # Non-empty, few duplicates, mostly alphabetic
-                                    cleaned = [
-                                        str(t).strip() for t in tokens if str(t).strip()
-                                    ]
-                                    if not cleaned:
-                                        return False
-                                    # Allow a few numeric tokens, but most should
-                                    # start with a letter
-                                    alpha_starts = sum(
-                                        1
-                                        for t in cleaned
-                                        if t and (t[0].isalpha() or t[0] == "_")
-                                    )
-                                    return alpha_starts >= max(
-                                        1, int(0.6 * len(cleaned))
-                                    )
-
-                                if _looks_like_header(header_row):
-                                    has_header = True
-
-                            if not has_header:
-                                num_cols_guess = len(header_row)
-                                headers = [f"col_{i}" for i in range(num_cols_guess)]
-                                first_data_row = header_row
-                            else:
-                                headers = header_row
-
-                            num_cols = (
-                                len(first_data_row) if first_data_row else len(headers)
-                            )
-
-                            # no header
-                            if not headers or all(
-                                (not str(h).strip()) for h in headers
-                            ):
-                                headers = [f"col_{i}" for i in range(num_cols or 1)]
-
-                            headers = self._sanitize_and_dedup_headers(headers)
-
-                            if num_cols <= 0:
-                                num_cols = len(headers)
-                            headers = headers[:num_cols]
-                            if not headers or any(
-                                h is None or h == "" for h in headers
-                            ):
-                                raise csv.Error("Invalid headers after sanitization")
-
-                            create_sql = f'''
-                                CREATE TABLE IF NOT EXISTS "{table_name}" (
-                                    {", ".join([f'"{h}" TEXT' for h in headers])}
-                                )
-                            '''
-                            insert_sql = f'''
-                                INSERT INTO "{table_name}" ({
-                                    ", ".join([f'"{h}"' for h in headers])
-                                })
-                                VALUES ({
-                                    ", ".join([":" + f"p{i}" for i in range(len(headers))])
-                                })
-                            '''
-
-                            with self._connector.session_scope() as session:
-                                logger.debug(
-                                    f"Table: {table_name}, headers(final): {headers}"
-                                )
-                                session.execute(text(create_sql))
-
-                                reader = csv.reader(lines, dialect)
-                                if has_header:
-                                    next(reader, None)
-
-                                batch_params: List[Dict[str, Any]] = []
-                                for row in reader:
-                                    if not row:
-                                        continue
-                                    if len(row) != len(headers):
-                                        if len(row) < len(headers):
-                                            row += [None] * (len(headers) - len(row))
-                                        else:
-                                            row = row[: len(headers)]
-                                    params = {
-                                        f"p{i}": (row[i] if i < len(row) else None)
-                                        for i in range(len(headers))
-                                    }
-                                    batch_params.append(params)
-                                    if len(batch_params) >= 1000:
-                                        session.execute(text(insert_sql), batch_params)
-                                        batch_params = []
-                                if batch_params:
-                                    session.execute(text(insert_sql), batch_params)
-                                session.commit()
-
-                            return True, table_name
-
-                        except csv.Error:
-                            self._import_with_simple_split_blocking(table_name, content)
-                            return True, table_name
-
-                    except UnicodeDecodeError:
-                        continue
-                    except Exception as e:
-                        logger.warning(f"Error with encoding {encoding}: {str(e)}")
-                        continue
+        def _worker():
+            results = {
+                "total_files": len(sqlite_files),
+                "successful": 0,
+                "failed": 0,
+                "tables_merged": [],
+            }

+            with self._connector.session_scope() as session:
+                # Get the underlying sqlite3 connection object
+                connection_proxy = session.connection()
+
+                # The way to reach the raw DBAPI connection differs across
+                # SQLAlchemy versions
                 try:
-                    with open(file_info["full_path"], "rb") as f:
-                        content = f.read().decode("ascii", errors="ignore")
-                    if content.strip():
-                        self._import_with_simple_split_blocking(table_name, content)
-                        return True, table_name
-                    else:
-                        raise ValueError("File is empty or unreadable")
-                except Exception as e:
-                    return (
-                        False,
-                        f"Failed to process {file_info['file_name']}: {str(e)}",
-                    )
+                    # SQLAlchemy 1.4+ / 2.0
+                    raw_conn = connection_proxy.connection.dbapi_connection
+                except AttributeError:
+                    try:
+                        # Older versions or certain drivers
+                        raw_conn = connection_proxy.connection
+                    except AttributeError:
+                        # Last resort
+                        raw_conn = session.get_bind().raw_connection()

-            except Exception as e:
-                return (
-                    False,
-                    f"Failed to process {file_info.get('full_path', '')}: {str(e)}",
-                )
+                # Make sure we actually obtained a raw sqlite3 connection
+                if not raw_conn:
+                    raise RuntimeError("Failed to get raw sqlite3 connection")

-        for file_info in csv_files:
-            ok, info = await self._run_in_thread(_process_one_file, file_info)
-            if ok:
-                results["successful"] += 1
-                if info:
-                    results["tables_created"].append(info)
-            else:
-                results["failed"] += 1
-                logger.error(info)
+                cursor = raw_conn.cursor()

-        return results
+                # Attach each source database under a unique alias and copy its
+                # tables into the main database
+                for db_path in sqlite_files:
+                    src_alias = f"src_db_{uuid.uuid4().hex[:8]}"
+                    try:
+                        # Detach anything left attached by a previous failed merge
+                        try:
+                            cursor.execute("PRAGMA database_list")
+                            attached_dbs = cursor.fetchall()
+                            for _, name, _ in attached_dbs:
+                                if name not in ("main", "temp"):
+                                    cursor.execute(f"DETACH DATABASE {name}")
+                        except Exception as cleanup_err:
+                            logger.warning(f"Cleanup warning: {cleanup_err}")
+
+                        cursor.execute(f"ATTACH DATABASE ? AS {src_alias}", (db_path,))
+
+                        cursor.execute(
+                            f"SELECT name, sql FROM {src_alias}.sqlite_master "
+                            f"WHERE type='table' AND name NOT LIKE 'sqlite_%'"
+                        )
+                        tables = cursor.fetchall()
+
+                        for table_name, create_sql in tables:
+                            cursor.execute(
+                                "SELECT name FROM sqlite_master "
+                                "WHERE type='table' "
+                                "AND name=?",
+                                (table_name,),
+                            )
+                            if not cursor.fetchone():
+                                cursor.execute(create_sql)
+                                cursor.execute(
+                                    f'INSERT INTO main."{table_name}" '
+                                    f'SELECT * FROM {src_alias}."{table_name}"'
+                                )
+                                results["tables_merged"].append(table_name)
+                            else:
+                                logger.warning(
+                                    f"Table '{table_name}' exists. Skipping."
+                                )
+
+                        raw_conn.commit()
+                        results["successful"] += 1
+
+                    except Exception as e:
+                        logger.error(f"Failed to merge {db_path}: {e}")
+                        results["failed"] += 1
+                        try:
+                            raw_conn.rollback()
+                        except Exception:
+                            pass
+                    finally:
+                        try:
+                            cursor.execute(f"DETACH DATABASE {src_alias}")
+                        except Exception:
+                            pass
+
+                return results
+
+        return await self._run_in_thread(_worker)

     def _import_with_simple_split_blocking(self, table_name: str, content: str):
         """Fallback method for malformed CSV files (blocking, executed via SQLAlchemy)"""