diff --git a/packages/dbgpt-core/src/dbgpt/util/benchmarks/StorageUtil.py b/packages/dbgpt-core/src/dbgpt/util/benchmarks/StorageUtil.py
index 4776fd002..8968789c8 100644
--- a/packages/dbgpt-core/src/dbgpt/util/benchmarks/StorageUtil.py
+++ b/packages/dbgpt-core/src/dbgpt/util/benchmarks/StorageUtil.py
@@ -10,6 +10,8 @@ class StorageUtil:
YUQUE_URL_PREFIX = "https://yuque.com"
+ GITHUB_FALCON_PREFIX = "https://github.com/eosphoros-ai/Falcon"
+
@staticmethod
def get_file_parse_type(file_path: Optional[str]) -> FileParseTypeEnum:
"""Get file parsing type based on file path
@@ -28,5 +30,7 @@ class StorageUtil:
if file_path.strip().startswith(StorageUtil.YUQUE_URL_PREFIX):
return FileParseTypeEnum.YU_QUE
+ if file_path.strip().startswith(StorageUtil.GITHUB_FALCON_PREFIX):
+ return FileParseTypeEnum.GITHUB
return FileParseTypeEnum.EXCEL
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/__init__.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/__init__.py
index 185e84d9b..800c5afb2 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/__init__.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/__init__.py
@@ -1,11 +1,5 @@
from .benchmark_service import BenchmarkService
-from .data_compare_service import DataCompareService
-from .file_parse_service import FileParseService
-from .user_input_execute_service import UserInputExecuteService
__all__ = [
"BenchmarkService",
- "FileParseService",
- "UserInputExecuteService",
- "DataCompareService",
]
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/benchmark_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/benchmark_service.py
index 410b351bb..6ca77169e 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/benchmark_service.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/benchmark_service.py
@@ -18,6 +18,7 @@ from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.storage.metadata import BaseDao
from dbgpt.util import PaginationResult, get_or_create_event_loop
+from dbgpt.util.benchmarks import StorageUtil
from dbgpt_serve.evaluate.service.benchmark.task.benchmark_agent_task import (
BenchmarkAgentTask,
)
@@ -40,7 +41,6 @@ from ...config import ServeConfig
from ...models.models import ServeDao, ServeEntity
from ..fetchdata.benchmark_data_manager import get_benchmark_manager
from .data_compare_service import DataCompareService
-from .ext.excel_file_parse import ExcelFileParseService
from .models import (
BaseInputModel,
BenchmarkDataSets,
@@ -49,7 +49,6 @@ from .models import (
BenchmarkModeTypeEnum,
BenchmarkTaskResult,
ContentTypeEnum,
- FileParseTypeEnum,
InputType,
OutputType,
)
@@ -61,10 +60,7 @@ executor = ThreadPoolExecutor(max_workers=5)
BENCHMARK_SERVICE_COMPONENT_NAME = "dbgpt_serve_evaluate_benchmark_service"
-STANDARD_BENCHMARK_FILE_PATH = os.path.join(
- BENCHMARK_DATA_ROOT_PATH,
- "2025_07_27_public_500_standard_benchmark_question_list.xlsx",
-)
+STANDARD_BENCHMARK_FILE_PATH = "https://github.com/eosphoros-ai/Falcon"
BENCHMARK_OUTPUT_RESULT_PATH = os.path.join(BENCHMARK_DATA_ROOT_PATH, "result")
@@ -94,11 +90,10 @@ class BenchmarkService(
super().__init__(system_app)
self.rag_service = get_rag_service(system_app)
self.prompt_service = get_prompt_service(system_app)
- self._file_parse_type = FileParseTypeEnum.EXCEL
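+        # Derive the parse type from the configured benchmark path: the Falcon
+        # GitHub URL maps to GITHUB, Yuque URLs to YU_QUE, anything else to EXCEL.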
+        self._file_parse_type = StorageUtil.get_file_parse_type(
+            STANDARD_BENCHMARK_FILE_PATH
+        )
- fps = ExcelFileParseService()
dcs = DataCompareService()
- self.user_input_execute_service = UserInputExecuteService(fps, dcs)
+        self.user_input_execute_service = UserInputExecuteService(
+            dcs, self._file_parse_type
+        )
self.trigger_executor = ThreadPoolExecutor(
max_workers=5, thread_name_prefix="benchmark-fileWrite"
@@ -289,7 +284,7 @@ class BenchmarkService(
await manager.load_data()
logger.info(
f"Benchmark dataset loaded from {manager._config.repo_url} "
- f"dir={manager._config.data_dir}"
+ f"dir={manager._config.data_dirs}"
)
except Exception as e:
logger.error(
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/ext/__init__.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/ext/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/ext/excel_file_parse.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/ext/excel_file_parse.py
new file mode 100644
index 000000000..618706756
--- /dev/null
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/ext/excel_file_parse.py
@@ -0,0 +1,176 @@
+import json
+import logging
+from typing import Dict, List, Optional
+
+import pandas as pd
+from openpyxl import load_workbook
+
+from dbgpt.util.benchmarks.ExcelUtils import ExcelUtils
+
+from ..file_parse_service import FileParseService
+from ..models import (
+ AnswerExecuteModel,
+ BaseInputModel,
+ BenchmarkDataSets,
+ DataCompareStrategyConfig,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ExcelFileParseService(FileParseService):
+ def parse_input_sets(self, path: str) -> BenchmarkDataSets:
+ """
+ Parse input sets from excel file
+ Args:
+ path: File location path
+ Returns:
+ BenchmarkDataSets: Parsed data sets
+ """
+ input_stream = self.get_input_stream(path)
+
+ if input_stream is None:
+ raise RuntimeError(f"file not found! path: {path}")
+
+ # Parse excel file to get data sets
+ input_sets = BenchmarkDataSets()
+ workbook = None
+
+ try:
+ workbook = load_workbook(input_stream, data_only=True)
+ input_list = []
+
+ # Get the first worksheet
+ sheet = workbook.worksheets[0]
+
+ for row_num in range(
+ 2, sheet.max_row + 1
+ ): # Skip header row (start from 1-based index)
+ row = sheet[row_num]
+ if ExcelUtils.is_row_empty(row):
+ continue
+
+                # Read columns 1-6 (0-based 0-5) and the prompt column (index 8)
+ serial_no_cell = row[0]
+ analysis_model_id_cell = row[1]
+ question_cell = row[2]
+ self_define_tags_cell = row[3]
+ knowledge_cell = row[4]
+ llm_output_cell = row[5]
+ prompt_cell = row[8]
+
+ # Build input model
+ input_model = BaseInputModel(
+ serial_no=int(
+ ExcelUtils.get_cell_value_as_string(serial_no_cell) or "0"
+ ),
+ analysis_model_id=ExcelUtils.get_cell_value_as_string(
+ analysis_model_id_cell
+ ),
+ question=ExcelUtils.get_cell_value_as_string(question_cell),
+ self_define_tags=ExcelUtils.get_cell_value_as_string(
+ self_define_tags_cell
+ ),
+ llm_output=ExcelUtils.get_cell_value_as_string(llm_output_cell),
+ knowledge=ExcelUtils.get_cell_value_as_string(knowledge_cell),
+ prompt=ExcelUtils.get_cell_value_as_string(prompt_cell),
+ )
+
+ input_list.append(input_model)
+
+ input_sets.data_list = input_list
+ except Exception as e:
+ logger.error(f"parse excel error, path: {path}, errorMsg: {e}")
+ finally:
+ try:
+ if workbook is not None:
+ workbook.close()
+ except Exception as e:
+ logger.error(f"close workbook error, path: {path}, errorMsg: {e}")
+
+ return input_sets
+
+ def parse_standard_benchmark_sets(
+ self, standard_excel_path: str
+ ) -> List[AnswerExecuteModel]:
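+        """Read the standard benchmark Excel sheet and build answer models."""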
+ df = pd.read_excel(standard_excel_path, sheet_name=0)
+ outputs: List[AnswerExecuteModel] = []
+ for _, row in df.iterrows():
+ try:
+ serial_no = int(row["编号"])
+ except Exception:
+ continue
+ question = row.get("用户问题")
+ analysis_model_id = row.get("数据集ID")
+ llm_output = (
+ None if pd.isna(row.get("标准答案SQL")) else str(row.get("标准答案SQL"))
+ )
+ order_by = True
+ if not pd.isna(row.get("是否排序")):
+ try:
+ order_by = bool(int(row.get("是否排序")))
+ except Exception:
+ order_by = True
+
+ std_result: Optional[List[Dict[str, List[str]]]] = None
+ if not pd.isna(row.get("标准结果")):
+ std_result_raw = str(row.get("标准结果"))
+ std_result = self._parse_multi_standard_result(std_result_raw)
+
+ strategy_config = DataCompareStrategyConfig(
+ strategy="CONTAIN_MATCH",
+ order_by=order_by,
+ standard_result=std_result if std_result is not None else None,
+ )
+ outputs.append(
+ AnswerExecuteModel(
+ serialNo=serial_no,
+ analysisModelId=analysis_model_id,
+ question=question,
+ llmOutput=llm_output,
+ executeResult=std_result,
+ strategyConfig=strategy_config,
+ )
+ )
+ return outputs
+
+ def _parse_multi_standard_result(
+ self, std_result_raw: str
+ ) -> Optional[List[Dict[str, List[str]]]]:
+ """
+ Parse multiple standard results from raw string data.
+
+ Handles multiple results separated by newlines and parses each line as a dict.
+
+ Args:
+ std_result_raw (str): Raw standard result string with multiple lines
+
+ Returns:
+ Optional[List[Dict[str, List[str]]]]: List of parsed dictionaries,
+ or None if parsing fails or no valid data
+ """
+ try:
+ std_result_raw = std_result_raw.strip()
+ if not std_result_raw:
+ return None
+
+ # 处理多个结果,通过换行符分隔
+ result_lines = std_result_raw.split("\n")
+ result_list = []
+
+ for line in result_lines:
+ line = line.strip()
+ if line:
+ try:
+ result_list.append(json.loads(line))
+ except Exception as e:
+ logger.warning(
+ f"Failed to parse line as JSON: {line}, error: {e}"
+ )
+ continue
+
+ return result_list if result_list else None
+ except Exception as e:
+ logger.error(f"parse multiple standard results error: {e}")
+ return None
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/ext/falcon_file_parse.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/ext/falcon_file_parse.py
new file mode 100644
index 000000000..b1b26a10b
--- /dev/null
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/ext/falcon_file_parse.py
@@ -0,0 +1,572 @@
+import asyncio
+import json
+import logging
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+from dbgpt.util import get_or_create_event_loop
+from dbgpt_serve.evaluate.service.benchmark.file_parse_service import FileParseService
+from dbgpt_serve.evaluate.service.benchmark.models import (
+    AnswerExecuteModel,
+    BaseInputModel,
+    BenchmarkDataSets,
+    DataCompareStrategyConfig,
+)
+from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
+ FileLoadResult,
+ get_benchmark_manager,
+)
+
+logger = logging.getLogger(__name__)
+
+TEXT_SQL_PROMPT = """
+Given the following dataset, including field names and sampling information:
+{Schema}
+
+Containing the following knowledge:
+{Knowledge}
+
+Based on the above information, please generate the SQL for the following question:
+{Query}
+
+[Output Requirements]
+* Field names in the generated SQL must use the actual field names from the table schema.
+* Table names in the generated SQL must use the actual table names provided in the schema.
+* Physical field names in the generated SQL must originate from the corresponding physical tables; generating fields that do not exist in the table is not allowed.
+* Cartesian product calculations are not allowed in the generated SQL. This includes `CROSS JOIN`, `JOIN` operations missing `ON` or `USING` conditions, and multi-table joins without conditions set for all relationships between tables.
+* The generated SQL must strictly adhere to {dialect} syntax. If it does not comply with this syntax, please regenerate it.
+* Output only pure, executable SQL without any additional information.
+
+[Example]
+** Table 1 Information
+*** Table Name: orders
+*** DDL Statement:
+ CREATE TABLE "orders" (
+ "order_id" TEXT,
+ "customer_id" TEXT,
+ "item_type" TEXT,
+ "order_date" DATE
+ )
+*** Field Information:
+|Field Name|Field Type|Sample Data|
+|:--:|:--:|:--:|
+|order_id|text|CN-2025-2562,CN-2025-8623,CN-2025-6535|
+|customer_id|text|52351,56263,71252|
+|item_type|text|办公用品,设备,家具|
+|order_date|date|2023-02-21,2024-03-30,2024-12-20|
+
+User Question: 请帮我统计下各商品类型的订单数
+Final Output SQL:
+SELECT
+ item_type,
+ COUNT(*) as order_count
+FROM
+ orders
+GROUP BY
+ item_type;
+"""
+
+
+@dataclass
+class BenchmarkDataItem:
+ """benchmark data info item"""
+
+ question_id: int
+ db_id: str
+ question: str
+ sql: str
+ answer: List[Dict[str, List[str]]]
+ is_order: str
+
+ @staticmethod
+ def from_dict(data: dict) -> "BenchmarkDataItem":
+ return BenchmarkDataItem(
+ question_id=data.get("question_id", 0),
+ db_id=str(data.get("db_id", "")),
+ question=str(data.get("question", "")),
+ sql=str(data.get("SQL", "")),
+ answer=data.get("answer", []),
+ is_order=data.get("is_order", "0"),
+ )
+
+@dataclass
+class ColumnItem:
+ """column info Item"""
+
+ column_id: int
+ column_name: str
+ column_type: str
+ sample_values: list
+
+ @staticmethod
+ def from_dict(data: dict) -> "ColumnItem":
+ """从字典创建 ColumnItem 实例
+
+ Args:
+ data: 包含列信息的字典
+
+ Returns:
+ ColumnItem: 列信息实例
+ """
+ return ColumnItem(
+ column_id=data.get("column_id", 0),
+ column_name=data.get("column_name", ""),
+ column_type=data.get("column_type", ""),
+ sample_values=data.get("sample_values", []),
+ )
+
+
+@dataclass
+class TableDDLItem:
+ """Table DDL Info Item"""
+
+ table_id: int
+ table_name: str
+ columns: List[ColumnItem]
+ ddl: Optional[str] = None
+
+ @staticmethod
+ def from_dict(data: dict) -> "TableDDLItem":
+ """从字典创建 TableDDLItem 实例
+
+ Args:
+ data: 包含表信息的字典
+
+ Returns:
+ TableDDLItem: 表信息实例
+ """
+ columns_data = data.get("columns", [])
+ columns = [ColumnItem.from_dict(col) for col in columns_data]
+
+ return TableDDLItem(
+ table_id=data.get("table_id", 0),
+ table_name=data.get("table_name", ""),
+ columns=columns,
+ ddl=data.get("ddl"),
+ )
+
+
+@dataclass
+class TableDataItem:
+ """Table Data Info Item"""
+
+ db_id: str
+ table_ddl: List[TableDDLItem]
+
+ @staticmethod
+ def from_dict(data: dict) -> "TableDataItem":
+ """从字典创建 TableDataItem 实例
+
+ Args:
+ data: 包含数据库表信息的字典
+
+ Returns:
+ TableDataItem: 数据库表信息实例
+ """
+ tables_data = data.get("tables", [])
+ table_ddl = [TableDDLItem.from_dict(table) for table in tables_data]
+
+ return TableDataItem(
+ db_id=str(data.get("db_id", "")),
+ table_ddl=table_ddl,
+ )
+
+class SafeDict(dict):
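+    """dict that leaves unknown format placeholders intact for str.format_map."""
+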
+ def __missing__(self, key):
+ return '{' + key + '}'
+
+
+class FalconFileParseService(FileParseService):
+ def __init__(self):
+ super().__init__()
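+        # Paths of the benchmark data files, relative to the Falcon repository root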
+ self._dev_data_file = "dev_data/dev.json"
+ self._dev_table_ddl_file = "dev_data/tables.json"
+
+ self.benchmark_manager = get_benchmark_manager()
+
+ self._dev_data: Optional[FileLoadResult] = None
+ self._dev_table_ddl: Optional[FileLoadResult] = None
+ self._data_loaded = False
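+        # Data is loaded lazily on the first parse call; see _ensure_data_loaded()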
+
+ @staticmethod
+ def _format_answer_list(answer_list: List[Dict[str, List[str]]]) -> str:
+ """格式化 answer 列表为字符串
+
+ Args:
+ answer_list: 答案列表,每个元素是字典,字典的值是字符串列表
+
+ Returns:
+ str: JSON 格式的字符串,如果列表为空则返回空字符串
+ """
+ if not answer_list:
+ return ""
+
+ try:
+            # Convert the answer list into a JSON string, one answer per line
+            return "\n".join(
+                json.dumps(item, ensure_ascii=False) for item in answer_list
+            )
+ except Exception as e:
+ logger.warning(f"Failed to format answer list: {e}")
+ return str(answer_list)
+
+ def _ensure_data_loaded(self):
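+        """Load dev.json and tables.json from the Falcon repo once and cache them."""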
+ if self._data_loaded:
+ return
+ logger.info("Loading benchmark data for the first time...")
+ try:
+ self._dev_data, self._dev_table_ddl = self._load_data_sync()
+ self._data_loaded = True
+ logger.info("Benchmark data loaded successfully")
+ except Exception as e:
+ logger.error(f"Failed to load benchmark data: {e}", exc_info=True)
+ raise RuntimeError(f"Failed to load benchmark data: {e}")
+
+
+ def parse_input_sets(self, path: str) -> BenchmarkDataSets:
+ """
+ Parse input sets from github repo
+ Args:
+ path: File URL path
+ Returns:
+ BenchmarkDataSets: Parsed data sets
+ """
+ self._ensure_data_loaded()
+
+ try:
+            # 1. Parse the benchmark question data
+ benchmark_data_list = self._parse_benchmark_data(self._dev_data)
+ if not benchmark_data_list:
+ logger.error("Failed to parse benchmark data")
+ return BenchmarkDataSets(data_list=[])
+
+            # 2. Parse the table DDL data
+            table_ddl_data_list = self._parse_table_ddl_data(self._dev_table_ddl)
+            if not table_ddl_data_list:
+                logger.error("Failed to parse table ddl data")
+ return BenchmarkDataSets(data_list=[])
+ table_ddl_data_map = {x.db_id: x.table_ddl for x in table_ddl_data_list}
+
+            # 3. Convert the question data into BaseInputModel items and
+            #    attach the standard answers
+ input_models = []
+ for idx, question_item in enumerate(benchmark_data_list, start=1):
+ input_model = BaseInputModel(
+ serial_no=question_item.question_id,
+ analysis_model_id=question_item.db_id,
+ question=question_item.question,
+ self_define_tags="",
+ knowledge="",
+ llm_output=self._format_answer_list(question_item.answer),
+                    prompt=self.load_benchmark_prompt_template(
+                        question_item,
+                        table_ddl_data_map.get(question_item.db_id),
+                    ),
+ )
+ input_models.append(input_model)
+ logger.info(f"Successfully parsed {len(input_models)} question items")
+ return BenchmarkDataSets(data_list=input_models)
+ except Exception as e:
+ logger.error(
+ f"load remote benchmark data error, error: {str(e)}",
+ exc_info=True,
+ )
+ return BenchmarkDataSets(data_list=[])
+
+ def parse_standard_benchmark_sets(
+ self, standard_excel_path: str
+ ) -> List[AnswerExecuteModel]:
+        """Build standard-answer models from the loaded Falcon dev data.
+
+        The `standard_excel_path` argument is kept for interface compatibility
+        and is not used by this implementation.
+        """
+ self._ensure_data_loaded()
+
+ outputs: List[AnswerExecuteModel] = []
+        # 1. Parse the benchmark question data
+ benchmark_data_list = self._parse_benchmark_data(self._dev_data)
+ if not benchmark_data_list:
+ logger.error("Failed to parse benchmark data")
+ return outputs
+
+ for idx, question_item in enumerate(benchmark_data_list, start=1):
+ serial_no = question_item.question_id
+ question = question_item.question
+ analysis_model_id = question_item.db_id
+ llm_output = question_item.sql
+ order_by = True
+ if question_item.is_order:
+ try:
+ order_by = bool(int(question_item.is_order))
+ except Exception:
+ order_by = True
+
+ std_result: Optional[List[Dict[str, List[str]]]] = None
+ if question_item.answer:
+ std_result = self._parse_multi_standard_result(question_item.answer)
+
+ strategy_config = DataCompareStrategyConfig(
+ strategy="CONTAIN_MATCH",
+ order_by=order_by,
+ standard_result=std_result if std_result is not None else None,
+ )
+ outputs.append(
+ AnswerExecuteModel(
+ serialNo=serial_no,
+ analysisModelId=analysis_model_id,
+ question=question,
+ llmOutput=llm_output,
+ executeResult=std_result,
+ strategyConfig=strategy_config,
+ )
+ )
+ return outputs
+
+ def _parse_benchmark_data(
+ self, benchmark_data: Optional[FileLoadResult]
+ ) -> Optional[List[BenchmarkDataItem]]:
+ """
+        Parse the benchmark question data
+        Args:
+            benchmark_data: Question file data loaded from GitHub
+        Returns:
+            List[BenchmarkDataItem]: Question items, or None if parsing fails
+ """
+ if not benchmark_data or not benchmark_data.rows:
+ return None
+
+ if benchmark_data.failed_count > 0:
+ logger.warning(
+ f"Question data has {benchmark_data.failed_count} failed rows"
+ )
+
+ benchmark_data_list = []
+ for row in benchmark_data.rows:
+ if not isinstance(row.data, dict):
+ logger.warning(
+ f"Row {row.line_no} data is not a dict: {type(row.data)}"
+ )
+ continue
+
+ benchmark_data_item = BenchmarkDataItem.from_dict(row.data)
+
+ if (
+ not benchmark_data_item.question_id
+ or not benchmark_data_item.question
+ or not benchmark_data_item.db_id
+ ):
+ logger.warning(
+ f"Row {row.line_no} missing required fields: "
+ f"question_id={benchmark_data_item.question_id}, "
+ f"question={benchmark_data_item.question}, "
+ f"db_id={benchmark_data_item.db_id}"
+ )
+ continue
+
+ benchmark_data_list.append(benchmark_data_item)
+
+ if not benchmark_data_list:
+ logger.error("No valid benchmark data parsed")
+ return None
+
+ logger.info(
+ f"Successfully parsed {len(benchmark_data_list)} benchmark data"
+ f" from {len(benchmark_data.rows)} rows"
+ )
+        # TODO: temporarily only keep the first 100 items
+ return benchmark_data_list[:100]
+
+
+ def _parse_table_ddl_data(
+ self, table_ddl_data: Optional[FileLoadResult]
+ ) -> Optional[List[TableDataItem]]:
+ """
+        Parse the table DDL data
+        Args:
+            table_ddl_data: Table DDL data loaded from GitHub
+        Returns:
+            List[TableDataItem]: Table DDL items, or None if parsing fails
+ """
+ if not table_ddl_data or not table_ddl_data.rows:
+ logger.warning("table ddl data is None")
+ return None
+ if table_ddl_data.failed_count > 0:
+ logger.warning(
+ f"table ddl data has {table_ddl_data.failed_count} failed items"
+ )
+
+ table_ddl_data_list = []
+ for row in table_ddl_data.rows:
+ if not isinstance(row.data, dict):
+ logger.warning(
+ f"Row {row.line_no} data is not a dict: {type(row.data)}"
+ )
+ continue
+
+ table_ddl_data_item = TableDataItem.from_dict(row.data)
+
+ if (
+ not table_ddl_data_item.db_id
+ or not table_ddl_data_item.table_ddl
+ ):
+ logger.warning(
+ f"Row {row.line_no} missing required fields: "
+ f"db_id={table_ddl_data_item.db_id}, "
+ f"table_ddl={table_ddl_data_item.table_ddl}"
+ )
+ continue
+
+ table_ddl_data_list.append(table_ddl_data_item)
+
+ if not table_ddl_data_list:
+ return None
+
+ logger.info(
+ f"Successfully parsed {len(table_ddl_data_list)} table DDL data"
+ f" from {len(table_ddl_data.rows)} rows"
+ )
+ return table_ddl_data_list
+
+
+ async def _async_load_data(
+ self,
+ ) -> Tuple[
+ Optional[FileLoadResult],
+ Optional[FileLoadResult],
+ ]:
+ """并发加载两个文件数据
+
+ 使用 asyncio.gather 并发执行两个异步任务,提高加载效率
+
+ Returns:
+ Tuple: (dev_data, dev_table_ddl)
+ """
+ dev_data, dev_table_ddl = await asyncio.gather(
+ self.benchmark_manager.load_file_from_github(self._dev_data_file),
+ self.benchmark_manager.load_file_from_github(self._dev_table_ddl_file),
+ )
+ return dev_data, dev_table_ddl
+
+ def _load_data_sync(
+ self,
+ ) -> Tuple[
+ Optional[FileLoadResult],
+ Optional[FileLoadResult],
+ ]:
+ """在同步上下文中加载数据
+
+ 智能检测当前事件循环状态:
+ - 如果事件循环正在运行,使用线程池在新线程中执行异步代码
+ - 如果没有运行中的事件循环,直接使用 run_until_complete
+
+ Returns:
+ Tuple: (dev_data, dev_table_ddl)
+ """
+ try:
+            # Try to get the currently running event loop
+            asyncio.get_running_loop()
+            # A loop is already running, so execute the coroutine in a thread pool
+            import concurrent.futures
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                future = executor.submit(asyncio.run, self._async_load_data())
+                return future.result()
+        except RuntimeError:
+            # No running loop, execute normally
+ loop = get_or_create_event_loop()
+ return loop.run_until_complete(self._async_load_data())
+
+ def _run_in_new_loop(
+ self,
+ ) -> Tuple[
+ Optional[FileLoadResult],
+ Optional[FileLoadResult],
+ ]:
+ """
+        Run the async task in a new event loop
+
+ Returns:
+ Tuple: (dev_data, dev_table_ddl)
+ """
+ new_loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(new_loop)
+ try:
+ return new_loop.run_until_complete(self._async_load_data())
+ finally:
+ new_loop.close()
+
+ @staticmethod
+ def _build_table_schema_info(table_ddl_list: Optional[List[TableDDLItem]]) -> str:
+ """构建表的 Schema 信息
+
+ Args:
+ table_ddl_list: 表 DDL 信息列表
+
+ Returns:
+ str: 格式化的表 Schema 信息字符串
+ """
+ if not table_ddl_list:
+ return ""
+
+ schema_parts = []
+
+ for idx, table in enumerate(table_ddl_list, start=1):
+            # Table header info
+            schema_parts.append(f"** Table {idx} Information")
+            schema_parts.append(f"*** Table Name: {table.table_name}")
+
+            # Add the DDL statement if present
+            if table.ddl:
+                schema_parts.append("*** DDL Statement:")
+                # The DDL may span multiple lines; indent each line
+                ddl_lines = table.ddl.strip().split("\n")
+                for ddl_line in ddl_lines:
+                    schema_parts.append(f"    {ddl_line}")
+
+            # Column info header, rendered as a markdown-style table
+            schema_parts.append("*** Field Information:")
+            schema_parts.append("|Field Name|Field Type|Sample Data|")
+            schema_parts.append("|:--:|:--:|:--:|")
+
+            # Add info for each column
+            for column in table.columns:
+                # Format sample values, capped at a reasonable length
+                if column.sample_values:
+                    # Take the first few sample values, comma-joined, to keep it short
+                    sample_count = min(3, len(column.sample_values))
+                    sample_str = ",".join(
+                        str(val) for val in column.sample_values[:sample_count]
+                    )
+                    # Truncate if the sample string is still too long
+                    if len(sample_str) > 100:
+                        sample_str = sample_str[:97] + "..."
+                else:
+                    sample_str = "-"
+
+                schema_parts.append(
+                    f"|{column.column_name}|{column.column_type}|{sample_str}|"
+                )
+
+            # Add a blank line between tables
+ if idx < len(table_ddl_list):
+ schema_parts.append("")
+
+ return "\n".join(schema_parts)
+
+ def _parse_multi_standard_result(
+ self, answer_list: List[Dict[str, List[str]]]
+ ) -> Optional[List[Dict[str, List[str]]]]:
+ """
+        Parse the standard answer results
+
+        Args:
+            answer_list: Answer list, already in the expected format
+
+        Returns:
+            Optional[List[Dict[str, List[str]]]]: The answer list, or None if empty
+ """
+ try:
+ if not answer_list:
+ return None
+            return answer_list
+ except Exception as e:
+ logger.error(f"parse standard results error: {e}")
+ return None
+
+    def load_benchmark_prompt_template(
+        self,
+        question_item: BenchmarkDataItem,
+        table_ddl: Optional[List[TableDDLItem]],
+    ) -> str:
+        """
+        Build the benchmark prompt from the text-to-SQL prompt template
+        """
+ schema = self._build_table_schema_info(table_ddl)
+ format_params = {
+ "Schema": schema,
+ "Knowledge": "",
+ "Query": question_item.question
+ }
+ # 使用 SafeDict 和 format_map 实现非严格模式,缺失的变量不会报错
+ return TEXT_SQL_PROMPT.format_map(SafeDict(format_params))
\ No newline at end of file
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py
index 27bc8f881..0498b7211 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py
@@ -4,20 +4,16 @@ import logging
import os
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import List, Optional, Any, Dict
import pandas as pd
from openpyxl import Workbook, load_workbook
-from dbgpt.util.benchmarks.ExcelUtils import ExcelUtils
-from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao
-
from .models import (
AnswerExecuteModel,
BaseInputModel,
BenchmarkDataSets,
BenchmarkExecuteConfig,
- DataCompareStrategyConfig,
OutputType,
RoundAnswerConfirmModel,
)
@@ -27,8 +23,6 @@ logger = logging.getLogger(__name__)
class FileParseService(ABC):
def __init__(self):
- self._benchmark_dao = BenchmarkResultDao()
-
# export column configuration file path
self._column_config_file_path = os.path.join(
os.path.dirname(__file__),
@@ -56,7 +50,6 @@ class FileParseService(ABC):
data.append(AnswerExecuteModel.from_dict(obj))
return data
- @abstractmethod
def write_data_compare_result(
self,
path: str,
@@ -65,15 +58,100 @@ class FileParseService(ABC):
is_execute: bool,
llm_count: int,
):
- """Write compare results to File
+ """Write compare results to an Excel file
- Args:
- path: Output file path
- round_id: Round ID
- confirm_models: List of answer confirm models
- is_execute: Whether to execute the comparison
- llm_count: LLM count
+
+        The output is an Excel ('.xlsx') file with a sheet named
+        'benchmark_compare_result'. If the file already exists, rows are
+        appended; otherwise a new file with headers is created.
"""
+ try:
+ # Ensure output directory exists
+ output_dir = Path(path).parent
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ output_file = path
+
+ headers = [
+ "serialNo",
+ "analysisModelId",
+ "question",
+ "selfDefineTags",
+ "prompt",
+ "standardAnswerSql",
+ "standardAnswer",
+ "llmCode",
+ "llmOutput",
+ "executeResult",
+ "errorMsg",
+ "compareResult",
+ ]
+
+ # Load or create workbook and sheet
+ if Path(output_file).exists():
+ workbook = load_workbook(str(output_file))
+ if "benchmark_compare_result" in workbook.sheetnames:
+ worksheet = workbook["benchmark_compare_result"]
+ else:
+ worksheet = workbook.create_sheet("benchmark_compare_result")
+ # Write headers if new sheet
+ for col_idx, header in enumerate(headers, 1):
+ worksheet.cell(row=1, column=col_idx, value=header)
+ else:
+ workbook = Workbook()
+ worksheet = workbook.active
+ worksheet.title = "benchmark_compare_result"
+ # Write headers
+ for col_idx, header in enumerate(headers, 1):
+ worksheet.cell(row=1, column=col_idx, value=header)
+
+ # Determine start row to append
+ start_row = worksheet.max_row + 1 if worksheet.max_row else 2
+
+ # Append rows
+ for idx, cm in enumerate(confirm_models):
+ row_data = [
+ cm.serialNo,
+ cm.analysisModelId,
+ cm.question,
+ cm.selfDefineTags,
+ cm.prompt,
+ cm.standardAnswerSql,
+ self._format_set_result(cm.strategyConfig.standard_result)
+ if cm.strategyConfig is not None
+ else "",
+ cm.llmCode,
+ cm.llmOutput,
+ json.dumps(cm.executeResult, ensure_ascii=False)
+ if cm.executeResult is not None
+ else "",
+ cm.errorMsg,
+ cm.compareResult.value if cm.compareResult else None,
+ ]
+ for col_idx, value in enumerate(row_data, 1):
+ worksheet.cell(row=start_row + idx, column=col_idx, value=value)
+
+ # Autosize columns (simple strategy)
+ for column in worksheet.columns:
+ max_length = 0
+ column_letter = column[0].column_letter
+ for cell in column:
+ try:
+ if cell.value and len(str(cell.value)) > max_length:
+ max_length = len(str(cell.value))
+ except Exception:
+ pass
+ adjusted_width = min(max(max_length + 2, 10), 80)
+ worksheet.column_dimensions[column_letter].width = adjusted_width
+
+ workbook.save(str(output_file))
+ workbook.close()
+ logger.info(
+ f"[write_data_compare_result] compare written to Excel: {output_file}"
+ )
+ except Exception as e:
+ logger.error(
+ f"[write_data_compare_result] write excel error for path={path}: {e}"
+ )
def summary_and_write_multi_round_benchmark_result(
self, output_path: str, round_id: int
@@ -163,7 +241,6 @@ class FileParseService(ABC):
"""
pass
- @abstractmethod
def write_multi_round_benchmark_result(
self,
output_file_path: str,
@@ -186,149 +263,6 @@ class FileParseService(ABC):
start_index: Starting index (batch start row index)
offset: Offset(file rows offset)
"""
-
-
-class ExcelFileParseService(FileParseService):
- def parse_input_sets(self, path: str) -> BenchmarkDataSets:
- """
- Parse input sets from excel file
- Args:
- path: File location path
- Returns:
- BenchmarkDataSets: Parsed data sets
- """
- input_stream = self.get_input_stream(path)
-
- if input_stream is None:
- raise RuntimeError(f"file not found! path: {path}")
-
- # Parse excel file to get data sets
- input_sets = BenchmarkDataSets()
- workbook = None
-
- try:
- workbook = load_workbook(input_stream, data_only=True)
- input_list = []
-
- # Get the first worksheet
- sheet = workbook.worksheets[0]
-
- for row_num in range(
- 2, sheet.max_row + 1
- ): # Skip header row (start from 1-based index)
- row = sheet[row_num]
- if ExcelUtils.is_row_empty(row):
- continue
-
- # Get content from columns 1-6 (0-based index 0-5)
- serial_no_cell = row[0]
- analysis_model_id_cell = row[1]
- question_cell = row[2]
- self_define_tags_cell = row[3]
- knowledge_cell = row[4]
- llm_output_cell = row[5]
- prompt_cell = row[8]
-
- # Build input model
- input_model = BaseInputModel(
- serial_no=int(
- ExcelUtils.get_cell_value_as_string(serial_no_cell) or "0"
- ),
- analysis_model_id=ExcelUtils.get_cell_value_as_string(
- analysis_model_id_cell
- ),
- question=ExcelUtils.get_cell_value_as_string(question_cell),
- self_define_tags=ExcelUtils.get_cell_value_as_string(
- self_define_tags_cell
- ),
- llm_output=ExcelUtils.get_cell_value_as_string(llm_output_cell),
- knowledge=ExcelUtils.get_cell_value_as_string(knowledge_cell),
- prompt=ExcelUtils.get_cell_value_as_string(prompt_cell),
- )
-
- input_list.append(input_model)
-
- input_sets.data_list = input_list
- except Exception as e:
- logger.error(f"parse excel error, path: {path}, errorMsg: {e}")
- finally:
- try:
- if workbook is not None:
- workbook.close()
- except Exception as e:
- logger.error(f"close workbook error, path: {path}, errorMsg: {e}")
-
- return input_sets
-
- def parse_standard_benchmark_sets(
- self, standard_excel_path: str
- ) -> List[AnswerExecuteModel]:
- df = pd.read_excel(standard_excel_path, sheet_name=0)
- outputs: List[AnswerExecuteModel] = []
- for _, row in df.iterrows():
- try:
- serial_no = int(row["编号"])
- except Exception:
- continue
- question = row.get("用户问题")
- analysis_model_id = row.get("数据集ID")
- llm_output = (
- None if pd.isna(row.get("标准答案SQL")) else str(row.get("标准答案SQL"))
- )
- order_by = True
- if not pd.isna(row.get("是否排序")):
- try:
- order_by = bool(int(row.get("是否排序")))
- except Exception:
- order_by = True
-
- std_result: Optional[List[Dict[str, List[str]]]] = None
- if not pd.isna(row.get("标准结果")):
- std_result_raw = str(row.get("标准结果"))
- std_result = self._parse_multi_standard_result(std_result_raw)
-
- strategy_config = DataCompareStrategyConfig(
- strategy="CONTAIN_MATCH",
- order_by=order_by,
- standard_result=std_result if std_result is not None else None,
- )
- outputs.append(
- AnswerExecuteModel(
- serialNo=serial_no,
- analysisModelId=analysis_model_id,
- question=question,
- llmOutput=llm_output,
- executeResult=std_result,
- strategyConfig=strategy_config,
- )
- )
- return outputs
-
- def write_multi_round_benchmark_result(
- self,
- output_file_path: str,
- round_id: int,
- config: BenchmarkExecuteConfig,
- inputs: List[BaseInputModel],
- outputs: List[OutputType],
- start_index: int,
- offset: int,
- ) -> bool:
- """
- Write the benchmark Result to Excel File With Multi Round
-
- Args:
- output_file_path: Output file path
- round_id: Round ID
- config: Benchmark configuration
- inputs: List of input data
- outputs: List of output data
- start_index: Starting index (batch start row index)
- offset: Offset(file rows offset)
-
- Returns:
- bool: Returns True if write is successful, False otherwise
- """
try:
# 确保输出目录存在
output_dir = Path(output_file_path).parent
@@ -436,109 +370,6 @@ class ExcelFileParseService(FileParseService):
logger.error(f"write excel file error: {e}", exc_info=True)
return False
- def write_data_compare_result(
- self,
- path: str,
- round_id: int,
- confirm_models: List[RoundAnswerConfirmModel],
- is_execute: bool,
- llm_count: int,
- ):
- """Write compare results to an Excel file
-
- The output Excel file will be named as '.xlsx' and
- sheet name is 'benchmark_compare_result'. If the file exists, it will
- append rows; otherwise it will create a new file with headers.
- """
- try:
- # Ensure output directory exists
- output_dir = Path(path).parent
- output_dir.mkdir(parents=True, exist_ok=True)
-
- output_file = path
-
- headers = [
- "serialNo",
- "analysisModelId",
- "question",
- "selfDefineTags",
- "prompt",
- "standardAnswerSql",
- "standardAnswer",
- "llmCode",
- "llmOutput",
- "executeResult",
- "errorMsg",
- "compareResult",
- ]
-
- # Load or create workbook and sheet
- if Path(output_file).exists():
- workbook = load_workbook(str(output_file))
- if "benchmark_compare_result" in workbook.sheetnames:
- worksheet = workbook["benchmark_compare_result"]
- else:
- worksheet = workbook.create_sheet("benchmark_compare_result")
- # Write headers if new sheet
- for col_idx, header in enumerate(headers, 1):
- worksheet.cell(row=1, column=col_idx, value=header)
- else:
- workbook = Workbook()
- worksheet = workbook.active
- worksheet.title = "benchmark_compare_result"
- # Write headers
- for col_idx, header in enumerate(headers, 1):
- worksheet.cell(row=1, column=col_idx, value=header)
-
- # Determine start row to append
- start_row = worksheet.max_row + 1 if worksheet.max_row else 2
-
- # Append rows
- for idx, cm in enumerate(confirm_models):
- row_data = [
- cm.serialNo,
- cm.analysisModelId,
- cm.question,
- cm.selfDefineTags,
- cm.prompt,
- cm.standardAnswerSql,
- self._format_set_result(cm.strategyConfig.standard_result)
- if cm.strategyConfig is not None
- else "",
- cm.llmCode,
- cm.llmOutput,
- json.dumps(cm.executeResult, ensure_ascii=False)
- if cm.executeResult is not None
- else "",
- cm.errorMsg,
- cm.compareResult.value if cm.compareResult else None,
- ]
- for col_idx, value in enumerate(row_data, 1):
- worksheet.cell(row=start_row + idx, column=col_idx, value=value)
-
- # Autosize columns (simple strategy)
- for column in worksheet.columns:
- max_length = 0
- column_letter = column[0].column_letter
- for cell in column:
- try:
- if cell.value and len(str(cell.value)) > max_length:
- max_length = len(str(cell.value))
- except Exception:
- pass
- adjusted_width = min(max(max_length + 2, 10), 80)
- worksheet.column_dimensions[column_letter].width = adjusted_width
-
- workbook.save(str(output_file))
- workbook.close()
- logger.info(
- f"[write_data_compare_result] compare written to Excel: {output_file}"
- )
- except Exception as e:
- logger.error(
- f"[write_data_compare_result] write excel error for path={path}: {e}"
- )
-
def _get_value_by_source_type(
self,
field: str,
@@ -636,46 +467,6 @@ class ExcelFileParseService(FileParseService):
else:
return str(value) if value is not None else ""
- def _parse_multi_standard_result(
- self, std_result_raw: str
- ) -> Optional[List[Dict[str, List[str]]]]:
- """
- Parse multiple standard results from raw string data.
-
- Handles multiple results separated by newlines and parses each line as a dict.
-
- Args:
- std_result_raw (str): Raw standard result string with multiple lines
-
- Returns:
- Optional[List[Dict[str, List[str]]]]: List of parsed dictionaries,
- or None if parsing fails or no valid data
- """
- try:
- std_result_raw = std_result_raw.strip()
- if not std_result_raw:
- return None
-
- # 处理多个结果,通过换行符分隔
- result_lines = std_result_raw.split("\n")
- result_list = []
-
- for line in result_lines:
- line = line.strip()
- if line:
- try:
- result_list.append(json.loads(line))
- except Exception as e:
- logger.warning(
- f"Failed to parse line as JSON: {line}, error: {e}"
- )
- continue
-
- return result_list if result_list else None
- except Exception as e:
- logger.error(f"parse multiple standard results error: {e}")
- return None
-
def _format_set_result(
self, sql_result: List[Dict[str, List[str]]]
) -> Optional[str]:
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py
index 763729d32..7cd47503c 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/user_input_execute_service.py
@@ -4,7 +4,6 @@ import logging
import os
from typing import Dict, List, Optional, Union
-from dbgpt.util.benchmarks import StorageUtil
from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
BENCHMARK_DEFAULT_DB_SCHEMA,
@@ -12,6 +11,8 @@ from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
)
from .data_compare_service import DataCompareService
+from .ext.excel_file_parse import ExcelFileParseService
+from .ext.falcon_file_parse import FalconFileParseService
from .file_parse_service import FileParseService
from .models import (
AnswerExecuteModel,
@@ -33,14 +34,28 @@ logger = logging.getLogger(__name__)
class UserInputExecuteService:
def __init__(
- self, file_service: FileParseService, compare_service: DataCompareService
+        self,
+        compare_service: DataCompareService,
+        file_type: Optional[FileParseTypeEnum] = None,
):
- self.file_service = file_service
self.compare_service = compare_service
+        self.file_service = self._create_file_service(file_type)
+ self.dao = BenchmarkResultDao()
# sql query timeout in seconds
self.query_timeout = float(os.getenv("BENCHMARK_SQL_TIMEOUT", 360.0))
+
+    def _create_file_service(
+        self, file_type: Optional[FileParseTypeEnum]
+    ) -> FileParseService:
+        """
+        Get the file service instance for the given file parse type.
+
+        Args:
+            file_type: The file parse type that selects the concrete service
+
+        Returns:
+            FileParseService: File service instance
+ """
+ if file_type == FileParseTypeEnum.GITHUB:
+ return FalconFileParseService()
+ elif file_type == FileParseTypeEnum.EXCEL:
+ return ExcelFileParseService()
+ raise NotImplementedError(f"filePraseType: {file_type} is not implemented yet")
+
def read_input_file(
self, input_file_path: str
) -> Union[List[BaseInputModel], None]:
@@ -53,15 +68,10 @@ class UserInputExecuteService:
Returns:
List[BaseInputModel]: Input data list
"""
- file_parse_type: FileParseTypeEnum = StorageUtil.get_file_parse_type(
+ input_sets: BenchmarkDataSets = self.file_service.parse_input_sets(
input_file_path
)
- if file_parse_type == FileParseTypeEnum.EXCEL:
- input_sets: BenchmarkDataSets = self.file_service.parse_input_sets(
- input_file_path
- )
- return input_sets.data_list
- return None
+ return input_sets.data_list
def post_dispatch(
self,
@@ -225,14 +235,13 @@ class UserInputExecuteService:
)
results = json.loads(summary_json) if summary_json else []
- dao = BenchmarkResultDao()
for item in results:
llm_code = item.get("llmCode")
right = int(item.get("right", 0))
wrong = int(item.get("wrong", 0))
failed = int(item.get("failed", 0))
exception = int(item.get("exception", 0))
- dao.upsert_summary(
+ self.dao.upsert_summary(
round_id,
location,
llm_code,
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py
index 205383b85..f3fcf387a 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py
@@ -4,15 +4,17 @@ import hashlib
import json
import logging
import os
+import re
import shutil
import tempfile
import threading
import time
+import uuid
import zipfile
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeoutError
from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Literal, Optional, Tuple, cast
import aiohttp
from sqlalchemy import text
@@ -24,6 +26,90 @@ from dbgpt_ext.datasource.rdbms.conn_sqlite import SQLiteConnector
logger = logging.getLogger(__name__)
+# ---- Unified result models for load_file_from_github ----
+class FailureDetail(BaseModel):
+ line_no: int
+ error: str
+ line: str
+
+class Row(BaseModel):
+ line_no: int
+ data: Any
+
+class FileLoadResult(BaseModel):
+ type: Literal["jsonl", "json", "text"]
+ file_path: str
+ file_name: str
+ encoding: Optional[str] = None
+ rows: List[Row]
+ count: int
+ failed_count: int
+ failures: List[FailureDetail] = []
+
+
+class SqlFileItem(BaseModel):
+ """Represents a single SQL file with its ID and content"""
+
+ sql_id: str
+ sql_content: str
+ file_path: str
+ file_name: str
+ encoding: Optional[str] = None
+
+
+class GoldenSqlListResult(BaseModel):
+ """Result object for golden SQL list loading
+
+ Provides efficient lookup by SQL ID with dict-like interface.
+ """
+ sql_items: Dict[str, SqlFileItem]
+ total_count: int
+ failed_count: int
+
+ def get_by_id(self, sql_id: str) -> Optional[SqlFileItem]:
+ """Get SQL item by ID
+
+ Args:
+ sql_id: The SQL file ID (filename prefix without extension)
+
+ Returns:
+ SqlFileItem if found, None otherwise
+ """
+ return self.sql_items.get(sql_id)
+
+ def get_sql_content(self, sql_id: str) -> Optional[str]:
+ """Get SQL content by ID
+
+ Args:
+ sql_id: The SQL file ID (filename prefix without extension)
+
+ Returns:
+ SQL content string if found, None otherwise
+ """
+ item = self.sql_items.get(sql_id)
+ return item.sql_content if item else None
+
+ def list_all_ids(self) -> List[str]:
+ """Get list of all SQL IDs
+
+ Returns:
+ List of SQL IDs sorted alphabetically
+ """
+ return sorted(self.sql_items.keys())
+
+ def __len__(self) -> int:
+ """Return number of successfully loaded SQL files"""
+ return len(self.sql_items)
+
+ def __contains__(self, sql_id: str) -> bool:
+ """Check if SQL ID exists"""
+ return sql_id in self.sql_items
+
+ def __iter__(self):
+ """Iterate over SQL items"""
+ return iter(self.sql_items.values())
+
+
BENCHMARK_DEFAULT_DB_SCHEMA = "ant_icube_dev."
@@ -36,12 +122,10 @@ class BenchmarkDataConfig(BaseModel):
db_path: str = os.path.join(
BENCHMARK_DATA_ROOT_PATH, f"{BENCHMARK_DEFAULT_DB_SCHEMA}db"
)
- table_mapping_file: str = os.path.join(
- BENCHMARK_DATA_ROOT_PATH, "table_mapping.json"
- )
+ table_mapping_file: Optional[str] = None
cache_expiry_days: int = 1
repo_url: str = "https://github.com/eosphoros-ai/Falcon"
- data_dir: str = "data/source"
+ data_dirs: List[str] = ["dev_data/dev_databases", "test_data/dev_databases"]
class BenchmarkDataManager(BaseComponent):
@@ -56,7 +140,6 @@ class BenchmarkDataManager(BaseComponent):
self._config = config or BenchmarkDataConfig()
self._http_session: Optional[aiohttp.ClientSession] = None
self._connector: Optional[SQLiteConnector] = None
- self._table_mappings = self._load_mappings()
self._lock = asyncio.Lock()
self.temp_dir: Optional[str] = None
@@ -73,9 +156,7 @@ class BenchmarkDataManager(BaseComponent):
async def async_before_stop(self):
try:
- logger.info("BenchmarkDataManager: closing resources before stop...")
await self.close()
- logger.info("BenchmarkDataManager: close done.")
except Exception as e:
logger.warning(f"BenchmarkDataManager: close failed: {e}")
@@ -120,7 +201,6 @@ class BenchmarkDataManager(BaseComponent):
async def load_data(self):
logger.info("BenchmarkDataManager: start load_data.")
-
try:
if not self._config.repo_url:
logger.info("BenchmarkDataManager: repo_url not set, skip auto load.")
@@ -132,69 +212,16 @@ class BenchmarkDataManager(BaseComponent):
logger.info(
f"BenchmarkDataManager: auto loading repo {self._config.repo_url} "
- f"dir={self._config.data_dir}"
+ f"dirs={self._config.data_dirs}"
)
await get_benchmark_manager(self.system_app).load_from_github(
- repo_url=self._config.repo_url, data_dir=self._config.data_dir
+ repo_url=self._config.repo_url, data_dirs=self._config.data_dirs
)
self._startup_loaded = True
logger.info("BenchmarkDataManager: auto load finished.")
except Exception as e:
logger.error(f"BenchmarkDataManager: auto load failed: {e}")
- def _sanitize_column_name(self, name: str) -> str:
- if name is None:
- return ""
- name = str(name).strip().strip('"').strip("'")
- invalid_chars = [
- "-",
- " ",
- ".",
- ",",
- ";",
- ":",
- "!",
- "?",
- "'",
- '"',
- "(",
- ")",
- "[",
- "]",
- "{",
- "}",
- "\t",
- "\r",
- "\n",
- "\x00",
- ]
- while name and name[-1] in invalid_chars:
- name = name[:-1]
- for ch in invalid_chars:
- if ch in name:
- name = name.replace(ch, "_")
- while "__" in name:
- name = name.replace("__", "_")
- if name and not (name[0].isalpha() or name[0] == "_"):
- name = "_" + name
- return name.lower()
-
- def _sanitize_and_dedup_headers(self, headers: List[str]) -> List[str]:
- sanitized: List[str] = []
- used: set = set()
- for idx, h in enumerate(headers):
- name = self._sanitize_column_name(h)
- if not name:
- name = f"col_{idx}"
- base = name
- k = 2
- while name in used or not name:
- name = f"{base}_{k}"
- k += 1
- used.add(name)
- sanitized.append(name)
- return sanitized
-
# ==========================================================
# 通用查询(阻塞实现,在线程池中调用,支持超时与可中断)
@@ -321,7 +348,9 @@ class BenchmarkDataManager(BaseComponent):
return [dict(zip(cols, row)) for row in rows]
async def load_from_github(
- self, repo_url: str, data_dir: str = "data/source"
+ self,
+ repo_url: str,
+ data_dirs: List[str] = ["dev_data/dev_databases", "test_data/dev_databases"],
) -> Dict:
"""Main method to load data from GitHub repository"""
try:
@@ -330,14 +359,26 @@ class BenchmarkDataManager(BaseComponent):
# 1. Download or use cached repository
repo_dir = await self._download_repo_contents(repo_url)
- # 2. Find all CSV files recursively
- csv_files = self._discover_csv_files(repo_dir, data_dir)
- if not csv_files:
- raise ValueError("No CSV files found")
- logger.info(f"Found {len(csv_files)} CSV files")
+ # 2. Find all SQLite files recursively in the specified data_dirs
+ all_sqlite_files = []
+ for d_dir in data_dirs:
+ try:
+ files = self._discover_sqlite_files(repo_dir, d_dir)
+ logger.info(f"Found {len(files)} SQLite files in {d_dir}")
+ all_sqlite_files.extend(files)
+ except ValueError as ve:
+                    # A missing directory only logs a warning; do not abort the load
+ logger.warning(f"Skip directory {d_dir}: {ve}")
- # 3. Import to SQLite
- result = await self._import_to_database(csv_files)
+ if not all_sqlite_files:
+ raise ValueError(
+ f"No SQLite files found in any of the directories: {data_dirs}"
+ )
+
+ logger.info(f"Total SQLite files to merge: {len(all_sqlite_files)}")
+
+ # 3. Merge all SQLite files into the main database
+ result = await self._merge_sqlite_databases(all_sqlite_files)
return result
except Exception as e:
@@ -346,6 +387,280 @@ class BenchmarkDataManager(BaseComponent):
finally:
self._cleanup_temp_dir()
+    async def load_file_from_github(
+        self, file_name: Optional[str] = None
+    ) -> Optional[FileLoadResult]:
+ """Download and read a specified file from a GitHub repository.
+
+ Supported file types: .json / .jsonl
+        `file_name` can be a relative path within the repository or a plain
+        filename (it will be searched recursively).
+
+ Unified return structure (FileLoadResult):
+ - type: "json" | "jsonl"
+ - file_path, file_name, encoding
+ - rows: List[{line_no:int, data:Any}] where data is parsed JSON object
+ - count: total number of rows
+ - failed_count: number of failed lines (non-zero for jsonl or malformed json)
+ - failures: details for failed lines
+
+ For JSON files:
+ - If the file contains a JSON array, each element becomes a Row
+ - If the file contains a single JSON object, it becomes one Row
+ - The structure is flexible and doesn't depend on specific keys
+ """
+ try:
+ if not file_name or not str(file_name).strip():
+ return None
+
+ # Download or use cached repository
+ repo_dir = await self._download_repo_contents(self._config.repo_url)
+
+ # Allowed file extensions
+ allowed_exts = {".jsonl", ".json"}
+
+ # Pre-check extension of `file_name` (if provided), otherwise filter by allowed list later
+ _, requested_ext = os.path.splitext(str(file_name).lower())
+ if requested_ext and requested_ext not in allowed_exts:
+ raise ValueError(f"Unsupported file type: {requested_ext}")
+
+ # Handle both relative path and plain filename cases
+ normalized = str(file_name).strip().lstrip("/").replace("\\", os.sep)
+ candidate_paths: List[str] = []
+
+ # Prefer direct path resolution using the relative path
+ direct_path = os.path.join(repo_dir, normalized)
+ if os.path.isfile(direct_path):
+ ext = os.path.splitext(direct_path.lower())[1]
+ if not requested_ext:
+ if ext in allowed_exts:
+ candidate_paths.append(direct_path)
+ elif ext == requested_ext:
+ candidate_paths.append(direct_path)
+
+ # If not found, recursively search by filename match
+ if not candidate_paths:
+ target_name = os.path.basename(normalized)
+ for root, _, files in os.walk(repo_dir):
+ for f in files:
+ if f == target_name:
+ full = os.path.join(root, f)
+ ext = os.path.splitext(f.lower())[1]
+ if not requested_ext:
+ if ext in allowed_exts:
+ candidate_paths.append(full)
+ elif ext == requested_ext:
+ candidate_paths.append(full)
+
+ if not candidate_paths:
+ raise FileNotFoundError(f"File not found: {file_name}")
+
+ # Choose a stable candidate (sorted by path length and lexicographical order)
+ chosen = sorted(candidate_paths, key=lambda p: (len(p), p))[0]
+ chosen_ext = os.path.splitext(chosen.lower())[1]
+
+ # Build repository-relative path for the file (avoid returning temp local path)
+ rel_path = os.path.relpath(chosen, repo_dir)
+ rel_path_posix = rel_path.replace(os.sep, "/")
+
+ # Try multiple encodings
+ encodings = ["utf-8", "iso-8859-1"]
+
+ # Handle .json files (array or single object)
+ if chosen_ext == ".json":
+ return await self._parse_json_file(
+ chosen, rel_path_posix, encodings
+ )
+
+ # Handle .jsonl files (line-delimited JSON)
+ elif chosen_ext == ".jsonl":
+ return await self._parse_jsonl_file(
+ chosen, rel_path_posix, encodings
+ )
+
+ else:
+ raise ValueError(f"Unsupported file extension: {chosen_ext}")
+
+ except Exception as e:
+ logger.error(f"Falcon repository Import failed: {str(e)}")
+ raise RuntimeError(f"Falcon repository file data loading failed: {e}") from e
+ finally:
+ self._cleanup_temp_dir()
+
+ async def _parse_json_file(
+ self, file_path: str, rel_path_posix: str, encodings: List[str]
+ ) -> FileLoadResult:
+ """Parse a JSON file (array or single object).
+
+ Args:
+ file_path: Absolute path to the JSON file
+ rel_path_posix: Repository-relative path in POSIX format
+ encodings: List of encodings to try
+
+ Returns:
+ FileLoadResult with parsed data
+ """
+ rows: List[Row] = []
+ failures: List[FailureDetail] = []
+ used_encoding: Optional[str] = None
+
+ # Try reading with different encodings
+ for enc in encodings:
+ try:
+ with open(file_path, "r", encoding=enc) as f:
+ content = f.read()
+
+ try:
+ data = json.loads(content)
+
+ # Handle JSON array
+ if isinstance(data, list):
+ for idx, item in enumerate(data, start=1):
+ rows.append(Row(line_no=idx, data=item))
+ # Handle single JSON object
+ elif isinstance(data, dict):
+ rows.append(Row(line_no=1, data=data))
+ else:
+ # Handle primitive types (string, number, etc.)
+ rows.append(Row(line_no=1, data=data))
+
+ used_encoding = enc
+ break
+
+ except json.JSONDecodeError as e:
+ failures.append(
+ FailureDetail(
+ line_no=1,
+ error=f"JSON decode error: {str(e)}",
+ line=content[:200],
+ )
+ )
+ used_encoding = enc
+ break
+
+ except UnicodeDecodeError:
+ continue
+ except Exception as e:
+ logger.warning(f"Read json with encoding {enc} failed: {e}")
+ continue
+
+ # Fallback: read as bytes and decode with ASCII ignoring errors
+ if used_encoding is None:
+ try:
+ with open(file_path, "rb") as f:
+ content = f.read().decode("ascii", errors="ignore")
+
+ try:
+ data = json.loads(content)
+
+ if isinstance(data, list):
+ for idx, item in enumerate(data, start=1):
+ rows.append(Row(line_no=idx, data=item))
+ elif isinstance(data, dict):
+ rows.append(Row(line_no=1, data=data))
+ else:
+ rows.append(Row(line_no=1, data=data))
+
+ except json.JSONDecodeError as e:
+ failures.append(
+ FailureDetail(
+ line_no=1,
+ error=f"JSON decode error: {str(e)}",
+ line=content[:200],
+ )
+ )
+
+ used_encoding = "ascii-ignore"
+ except Exception as e:
+ raise ValueError(f"Failed to read json file: {e}")
+
+ return FileLoadResult(
+ type="json",
+ file_path=rel_path_posix,
+ file_name=os.path.basename(file_path),
+ encoding=used_encoding,
+ rows=rows,
+ count=len(rows) + len(failures),
+ failed_count=len(failures),
+ failures=failures,
+ )
+
+ async def _parse_jsonl_file(
+ self, file_path: str, rel_path_posix: str, encodings: List[str]
+ ) -> FileLoadResult:
+ """Parse a JSONL file (line-delimited JSON).
+
+ Args:
+ file_path: Absolute path to the JSONL file
+ rel_path_posix: Repository-relative path in POSIX format
+ encodings: List of encodings to try
+
+ Returns:
+ FileLoadResult with parsed data
+ """
+ rows: List[Row] = []
+ failures: List[FailureDetail] = []
+ used_encoding: Optional[str] = None
+
+ # Prefer reading in text mode with multiple encodings
+ for enc in encodings:
+ try:
+ with open(file_path, "r", encoding=enc) as f:
+ for idx, line in enumerate(f, start=1):
+ s = line.strip()
+ if not s:
+ continue
+ try:
+ obj = json.loads(s)
+ rows.append(Row(line_no=idx, data=obj))
+ except Exception as e:
+ failures.append(
+ FailureDetail(
+ line_no=idx,
+ error=str(e),
+ line=s[:200],
+ )
+ )
+ used_encoding = enc
+ break
+ except UnicodeDecodeError:
+ continue
+ except Exception as e:
+ logger.warning(f"Read jsonl with encoding {enc} failed: {e}")
+ continue
+
+ # Fallback: read as bytes and decode with ASCII ignoring errors
+ if used_encoding is None:
+ try:
+ with open(file_path, "rb") as f:
+ for idx, raw_line in enumerate(f, start=1):
+ s = raw_line.decode("ascii", errors="ignore").strip()
+ if not s:
+ continue
+ try:
+ obj = json.loads(s)
+ rows.append(Row(line_no=idx, data=obj))
+ except Exception as e:
+ failures.append(
+ FailureDetail(
+ line_no=idx,
+ error=str(e),
+ line=s[:200],
+ )
+ )
+ used_encoding = "ascii-ignore"
+ except Exception as e:
+ raise ValueError(f"Failed to read jsonl file: {e}")
+
+ return FileLoadResult(
+ type="jsonl",
+ file_path=rel_path_posix,
+ file_name=os.path.basename(file_path),
+ encoding=used_encoding,
+ rows=rows,
+ count=(len(rows) + len(failures)),
+ failed_count=len(failures),
+ failures=failures,
+ )
+
async def get_table_info(self) -> Dict:
"""Get metadata about all tables"""
await self.init_connector()
@@ -445,7 +760,7 @@ class BenchmarkDataManager(BaseComponent):
return (mapped_name or "").lower()
async def _download_repo_contents(self, repo_url: str) -> str:
- """Download repository with caching"""
+ """Download repository with caching, supporting branch URLs"""
cache_path = self._get_cache_path(repo_url)
# Use cache if valid
@@ -455,21 +770,45 @@ class BenchmarkDataManager(BaseComponent):
# Download fresh copy
self.temp_dir = tempfile.mkdtemp()
- zip_url = (
- repo_url.replace("github.com", "api.github.com/repos") + "/zipball/main"
- )
+
+ # Simple parsing for github.com URLs
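+        # e.g. github.com/owner/repo/tree/branch -> (owner, repo, branch)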
+ github_pattern = r"github\.com/([^/]+)/([^/]+)(?:/tree/(.+))?"
+ match = re.search(github_pattern, repo_url)
+
+ if match:
+ owner, repo, branch = match.groups()
+ branch = branch or "main" # Default to main if no tree/branch specified
+ zip_url = f"https://api.github.com/repos/{owner}/{repo}/zipball/{branch}"
+ else:
+ # Fallback for generic structure or direct zip links
+ if repo_url.endswith(".zip"):
+ zip_url = repo_url
+ else:
+ # Default fallback behavior from original code
+ zip_url = (
+ repo_url.replace("github.com", "api.github.com/repos")
+ + "/zipball/main"
+ )
+
logger.info(f"Downloading from GitHub repo: {zip_url}")
try:
if self._http_session is None:
self._http_session = aiohttp.ClientSession()
- async with self._http_session.get(zip_url) as response:
- response.raise_for_status()
+
+ headers = {"Accept": "application/vnd.github.v3+json"}
+ async with self._http_session.get(zip_url, headers=headers) as response:
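+                # Surface non-200 responses (rate limits, missing branch, private repo) with the body text for easier debugging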
+ if response.status != 200:
+ text_resp = await response.text()
+ raise RuntimeError(
+ f"GitHub API Error {response.status}: {text_resp}"
+ )
+
zip_path = os.path.join(self.temp_dir, "repo.zip")
with open(zip_path, "wb") as f:
while True:
- chunk = await response.content.read(1024)
+ chunk = await response.content.read(1024 * 1024) # 1MB chunks
if not chunk:
break
f.write(chunk)
@@ -479,7 +818,6 @@ class BenchmarkDataManager(BaseComponent):
logger.info(f"Saved repository to cache: {cache_path}")
return self._extract_zip(zip_path)
-
except Exception as e:
self._cleanup_temp_dir()
raise RuntimeError(f"Failed to download repository: {str(e)}") from e
@@ -515,252 +853,112 @@ class BenchmarkDataManager(BaseComponent):
raise ValueError("No valid directory found after extraction")
return os.path.join(self.temp_dir, extracted_dirs[0])
- def _discover_csv_files(self, base_dir: str, search_dir: str) -> List[Dict]:
- """Find all CSV files recursively"""
+ def _discover_sqlite_files(self, base_dir: str, search_dir: str) -> List[str]:
+ """Find all SQLite files recursively in the search directory"""
full_search_dir = os.path.join(base_dir, search_dir) if search_dir else base_dir
if not os.path.exists(full_search_dir):
raise ValueError(f"Directory not found: {full_search_dir}")
- csv_files = []
+ sqlite_files = []
for root, _, files in os.walk(full_search_dir):
for file in files:
- if file.lower().endswith(".csv"):
- rel_path = os.path.relpath(root, start=base_dir)
- csv_files.append(
- {
- "full_path": os.path.join(root, file),
- "rel_path": rel_path,
- "file_name": file,
- }
- )
- return csv_files
+ if file.lower().endswith(".sqlite"):
+ full_path = os.path.join(root, file)
+ sqlite_files.append(full_path)
+ return sqlite_files
- async def _import_to_database(self, csv_files: List[Dict]) -> Dict:
- """Import CSV data to SQLite"""
+ async def _merge_sqlite_databases(self, sqlite_files: List[str]) -> Dict:
+ """Merge multiple SQLite files into the main database"""
await self.init_connector()
assert self._connector is not None
- results = {
- "total_files": len(csv_files),
- "successful": 0,
- "failed": 0,
- "tables_created": [],
- }
- def _process_one_file(file_info: Dict) -> Tuple[bool, Optional[str]]:
- table_name = ""
- try:
- path_parts = [p for p in file_info["rel_path"].split(os.sep) if p]
- table_name = "_".join(path_parts + [Path(file_info["file_name"]).stem])
- table_name = self._sanitize_table_name(table_name)
-
- with self._connector.session_scope() as session:
- session.execute(text(f'DROP TABLE IF EXISTS "{table_name}"'))
- session.commit()
- encodings = ["utf-8-sig", "utf-8", "latin-1", "iso-8859-1", "cp1252"]
-
- for encoding in encodings:
- try:
- with open(file_info["full_path"], "r", encoding=encoding) as f:
- content = f.read()
-
- if not content.strip():
- raise ValueError("File is empty")
-
- content = content.replace("\r\n", "\n").replace("\r", "\n")
- lines = [line for line in content.split("\n") if line.strip()]
- if not lines:
- raise ValueError("No data after normalization")
-
- header_line = lines[0]
- data_line = lines[1] if len(lines) > 1 else ""
-
- try:
- sample_for_sniff = "\n".join(lines[:10])
- sniffer = csv.Sniffer()
- try:
- dialect = sniffer.sniff(sample_for_sniff)
- except Exception:
- # Fallback: choose delimiter by counting common
- # separators in header/data line
- delims = [",", "\t", ";", "|"]
- counts = {
- d: (header_line.count(d) if header_line else 0)
- + (data_line.count(d) if data_line else 0)
- for d in delims
- }
- best = (
- max(counts, key=counts.get)
- if any(counts.values())
- else ","
- )
-
- class _DefaultDialect(csv.Dialect):
- delimiter = best
- quotechar = '"'
- doublequote = True
- skipinitialspace = False
- lineterminator = "\n"
- quoting = csv.QUOTE_MINIMAL
-
- dialect = _DefaultDialect()
-
- try:
- has_header = sniffer.has_header("\n".join(lines[:50]))
- except Exception:
- has_header = True
-
- header_row = (
- list(csv.reader([header_line], dialect))[0]
- if header_line
- else []
- )
- first_data_row = (
- list(csv.reader([data_line], dialect))[0]
- if data_line
- else []
- )
-
- # Heuristic: if has_header is False but header_row looks
- # like names (mostly alphabetic), treat as header
- if not has_header:
-
- def _looks_like_header(tokens: List[str]) -> bool:
- if not tokens:
- return False
- # 非空、重复少、字母比例高
- cleaned = [
- str(t).strip() for t in tokens if str(t).strip()
- ]
- if not cleaned:
- return False
- # 允许少量数字,但大多以字母开头
- alpha_starts = sum(
- 1
- for t in cleaned
- if t and (t[0].isalpha() or t[0] == "_")
- )
- return alpha_starts >= max(
- 1, int(0.6 * len(cleaned))
- )
-
- if _looks_like_header(header_row):
- has_header = True
-
- if not has_header:
- num_cols_guess = len(header_row)
- headers = [f"col_{i}" for i in range(num_cols_guess)]
- first_data_row = header_row
- else:
- headers = header_row
-
- num_cols = (
- len(first_data_row) if first_data_row else len(headers)
- )
-
- # no header
- if not headers or all(
- (not str(h).strip()) for h in headers
- ):
- headers = [f"col_{i}" for i in range(num_cols or 1)]
-
- headers = self._sanitize_and_dedup_headers(headers)
-
- if num_cols <= 0:
- num_cols = len(headers)
- headers = headers[:num_cols]
- if not headers or any(
- h is None or h == "" for h in headers
- ):
- raise csv.Error("Invalid headers after sanitization")
-
- create_sql = f'''
- CREATE TABLE IF NOT EXISTS "{table_name}" (
- {", ".join([f'"{h}" TEXT' for h in headers])}
- )
- '''
- insert_sql = f'''
- INSERT INTO "{table_name}" ({
- ", ".join([f'"{h}"' for h in headers])
- })
- VALUES ({
- ", ".join([":" + f"p{i}" for i in range(len(headers))])
- })
- '''
-
- with self._connector.session_scope() as session:
- logger.debug(
- f"Table: {table_name}, headers(final): {headers}"
- )
- session.execute(text(create_sql))
-
- reader = csv.reader(lines, dialect)
- if has_header:
- next(reader, None)
-
- batch_params: List[Dict[str, Any]] = []
- for row in reader:
- if not row:
- continue
- if len(row) != len(headers):
- if len(row) < len(headers):
- row += [None] * (len(headers) - len(row))
- else:
- row = row[: len(headers)]
- params = {
- f"p{i}": (row[i] if i < len(row) else None)
- for i in range(len(headers))
- }
- batch_params.append(params)
- if len(batch_params) >= 1000:
- session.execute(text(insert_sql), batch_params)
- batch_params = []
- if batch_params:
- session.execute(text(insert_sql), batch_params)
- session.commit()
-
- return True, table_name
-
- except csv.Error:
- self._import_with_simple_split_blocking(table_name, content)
- return True, table_name
-
- except UnicodeDecodeError:
- continue
- except Exception as e:
- logger.warning(f"Error with encoding {encoding}: {str(e)}")
- continue
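+        # All blocking sqlite3 work happens inside _worker, which runs in a worker thread via _run_in_thread so the event loop is not blocked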
+ def _worker():
+ results = {
+ "total_files": len(sqlite_files),
+ "successful": 0,
+ "failed": 0,
+ "tables_merged": [],
+ }
+ with self._connector.session_scope() as session:
+                # Get the underlying sqlite3 connection object
+ connection_proxy = session.connection()
+                # Different SQLAlchemy versions expose the raw DBAPI connection in different ways
try:
- with open(file_info["full_path"], "rb") as f:
- content = f.read().decode("ascii", errors="ignore")
- if content.strip():
- self._import_with_simple_split_blocking(table_name, content)
- return True, table_name
- else:
- raise ValueError("File is empty or unreadable")
- except Exception as e:
- return (
- False,
- f"Failed to process {file_info['file_name']}: {str(e)}",
- )
+ # SQLAlchemy 1.4+ / 2.0
+ raw_conn = connection_proxy.connection.dbapi_connection
+ except AttributeError:
+ try:
+                        # Older SQLAlchemy versions or certain drivers
+ raw_conn = connection_proxy.connection
+ except AttributeError:
+                        # Last resort: ask the engine for a raw connection
+ raw_conn = session.get_bind().raw_connection()
- except Exception as e:
- return (
- False,
- f"Failed to process {file_info.get('full_path', '')}: {str(e)}",
- )
+                # Make sure we actually obtained a raw sqlite3 connection
+ if not raw_conn:
+ raise RuntimeError("Failed to get raw sqlite3 connection")
- for file_info in csv_files:
- ok, info = await self._run_in_thread(_process_one_file, file_info)
- if ok:
- results["successful"] += 1
- if info:
- results["tables_created"].append(info)
- else:
- results["failed"] += 1
- logger.error(info)
+ cursor = raw_conn.cursor()
- return results
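+            # Merge strategy: ATTACH each source file under a unique alias,
+            # copy tables that are missing from main, then DETACH in a finally block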
+ for db_path in sqlite_files:
+ src_alias = f"src_db_{uuid.uuid4().hex[:8]}"
+ try:
+ try:
+ cursor.execute("PRAGMA database_list")
+ attached_dbs = cursor.fetchall()
+ for _, name, _ in attached_dbs:
+ if name not in ("main", "temp"):
+ cursor.execute(f"DETACH DATABASE {name}")
+ except Exception as cleanup_err:
+ logger.warning(f"Cleanup warning: {cleanup_err}")
+
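+                    # The source path is passed as a bound parameter; the alias is an identifier and cannot be parameterized, so it is interpolated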
+ cursor.execute(f"ATTACH DATABASE ? AS {src_alias}", (db_path,))
+
+ cursor.execute(
+ f"SELECT name, sql FROM {src_alias}.sqlite_master "
+ f"WHERE type='table' AND name NOT LIKE 'sqlite_%'"
+ )
+ tables = cursor.fetchall()
+
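+                    # Copy each user table that does not already exist in main; name collisions are skipped rather than merged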
+ for table_name, create_sql in tables:
+ cursor.execute(
+ "SELECT name FROM sqlite_master "
+ "WHERE type='table' "
+ "AND name=?",
+ (table_name,),
+ )
+ if not cursor.fetchone():
+ cursor.execute(create_sql)
+ cursor.execute(
+ f'INSERT INTO main."{table_name}" '
+ f'SELECT * FROM {src_alias}."{table_name}"'
+ )
+ results["tables_merged"].append(table_name)
+ else:
+ logger.warning(
+ f"Table '{table_name}' exists. Skipping."
+ )
+
+ raw_conn.commit()
+ results["successful"] += 1
+
+ except Exception as e:
+ logger.error(f"Failed to merge {db_path}: {e}")
+ results["failed"] += 1
+ try:
+ raw_conn.rollback()
+ except Exception:
+ pass
+ finally:
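+                        # Always detach the alias so repeated merges don't leak attachments or hit SQLite's attached-database limit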
+ try:
+ cursor.execute(f"DETACH DATABASE {src_alias}")
+ except Exception:
+ pass
+
+ return results
+
+ return await self._run_in_thread(_worker)
def _import_with_simple_split_blocking(self, table_name: str, content: str):
"""Fallback method for malformed CSV files (blocking, 使用 SQLAlchemy 执行)"""