feat(benchmark): update benchmark task to use latest Falcon GitHub repo questions

alan.cl
2025-12-22 15:03:56 +08:00
parent 3c7cfba3be
commit 1491407578
9 changed files with 1380 additions and 641 deletions

View File

@@ -10,6 +10,8 @@ class StorageUtil:
YUQUE_URL_PREFIX = "https://yuque.com"
GITHUB_FALCON_PREFIX = "https://github.com/eosphoros-ai/Falcon"
@staticmethod
def get_file_parse_type(file_path: Optional[str]) -> FileParseTypeEnum:
"""Get file parsing type based on file path
@@ -28,5 +30,7 @@ class StorageUtil:
if file_path.strip().startswith(StorageUtil.YUQUE_URL_PREFIX):
return FileParseTypeEnum.YU_QUE
if file_path.strip().startswith(StorageUtil.GITHUB_FALCON_PREFIX):
return FileParseTypeEnum.GITHUB
return FileParseTypeEnum.EXCEL
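
With this change, any path that begins with the Falcon repository URL resolves to the GITHUB parse type, while Excel stays the fallback. A minimal standalone sketch of the dispatch (prefix constants copied from above; enum members stand in as plain strings):

YUQUE_URL_PREFIX = "https://yuque.com"
GITHUB_FALCON_PREFIX = "https://github.com/eosphoros-ai/Falcon"

def get_file_parse_type(file_path):
    # Mirrors StorageUtil.get_file_parse_type: route by URL prefix, default to Excel
    if file_path and file_path.strip().startswith(YUQUE_URL_PREFIX):
        return "YU_QUE"
    if file_path and file_path.strip().startswith(GITHUB_FALCON_PREFIX):
        return "GITHUB"
    return "EXCEL"

assert get_file_parse_type("https://github.com/eosphoros-ai/Falcon") == "GITHUB"
assert get_file_parse_type("questions.xlsx") == "EXCEL"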

View File

@@ -1,11 +1,5 @@
from .benchmark_service import BenchmarkService
from .data_compare_service import DataCompareService
from .file_parse_service import FileParseService
from .user_input_execute_service import UserInputExecuteService
__all__ = [
"BenchmarkService",
"FileParseService",
"UserInputExecuteService",
"DataCompareService",
]

View File

@@ -18,6 +18,7 @@ from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.storage.metadata import BaseDao
from dbgpt.util import PaginationResult, get_or_create_event_loop
from dbgpt.util.benchmarks import StorageUtil
from dbgpt_serve.evaluate.service.benchmark.task.benchmark_agent_task import (
BenchmarkAgentTask,
)
@@ -40,7 +41,6 @@ from ...config import ServeConfig
from ...models.models import ServeDao, ServeEntity
from ..fetchdata.benchmark_data_manager import get_benchmark_manager
from .data_compare_service import DataCompareService
from .ext.excel_file_parse import ExcelFileParseService
from .models import (
BaseInputModel,
BenchmarkDataSets,
@@ -49,7 +49,6 @@ from .models import (
BenchmarkModeTypeEnum,
BenchmarkTaskResult,
ContentTypeEnum,
FileParseTypeEnum,
InputType,
OutputType,
)
@@ -61,10 +60,7 @@ executor = ThreadPoolExecutor(max_workers=5)
BENCHMARK_SERVICE_COMPONENT_NAME = "dbgpt_serve_evaluate_benchmark_service"
STANDARD_BENCHMARK_FILE_PATH = os.path.join(
BENCHMARK_DATA_ROOT_PATH,
"2025_07_27_public_500_standard_benchmark_question_list.xlsx",
)
STANDARD_BENCHMARK_FILE_PATH = "https://github.com/eosphoros-ai/Falcon"
BENCHMARK_OUTPUT_RESULT_PATH = os.path.join(BENCHMARK_DATA_ROOT_PATH, "result")
@@ -94,11 +90,10 @@ class BenchmarkService(
super().__init__(system_app)
self.rag_service = get_rag_service(system_app)
self.prompt_service = get_prompt_service(system_app)
self._file_parse_type = FileParseTypeEnum.EXCEL
self._file_parse_type = StorageUtil.get_file_parse_type(STANDARD_BENCHMARK_FILE_PATH)
fps = ExcelFileParseService()
dcs = DataCompareService()
self.user_input_execute_service = UserInputExecuteService(fps, dcs)
self.user_input_execute_service = UserInputExecuteService(dcs, self._file_parse_type)
self.trigger_executor = ThreadPoolExecutor(
max_workers=5, thread_name_prefix="benchmark-fileWrite"
@@ -289,7 +284,7 @@ class BenchmarkService(
await manager.load_data()
logger.info(
f"Benchmark dataset loaded from {manager._config.repo_url} "
f"dir={manager._config.data_dir}"
f"dir={manager._config.data_dirs}"
)
except Exception as e:
logger.error(
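
BenchmarkService now derives the parse type from the configured benchmark path instead of hardcoding Excel, and passes that type (rather than a concrete parser) to UserInputExecuteService. A condensed sketch of the new wiring, with the dbgpt types stubbed out as assumptions:

from enum import Enum

class FileParseTypeEnum(Enum):  # assumed shape of the real enum
    EXCEL = "excel"
    GITHUB = "github"

STANDARD_BENCHMARK_FILE_PATH = "https://github.com/eosphoros-ai/Falcon"

def get_file_parse_type(path: str) -> FileParseTypeEnum:
    # Stand-in for StorageUtil.get_file_parse_type
    if path.startswith("https://github.com/eosphoros-ai/Falcon"):
        return FileParseTypeEnum.GITHUB
    return FileParseTypeEnum.EXCEL

# The resolved type is computed once in __init__ and handed to
# UserInputExecuteService, which builds the matching parser from it.
assert get_file_parse_type(STANDARD_BENCHMARK_FILE_PATH) is FileParseTypeEnum.GITHUB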

View File

@@ -0,0 +1,176 @@
import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
import pandas as pd
from openpyxl import Workbook, load_workbook
from dbgpt.util.benchmarks.ExcelUtils import ExcelUtils
from ..file_parse_service import FileParseService
from ..models import (
AnswerExecuteModel,
BaseInputModel,
BenchmarkDataSets,
DataCompareStrategyConfig,
)
logger = logging.getLogger(__name__)
class ExcelFileParseService(FileParseService):
def parse_input_sets(self, path: str) -> BenchmarkDataSets:
"""
Parse input sets from excel file
Args:
path: File location path
Returns:
BenchmarkDataSets: Parsed data sets
"""
input_stream = self.get_input_stream(path)
if input_stream is None:
raise RuntimeError(f"file not found! path: {path}")
# Parse excel file to get data sets
input_sets = BenchmarkDataSets()
workbook = None
try:
workbook = load_workbook(input_stream, data_only=True)
input_list = []
# Get the first worksheet
sheet = workbook.worksheets[0]
for row_num in range(
2, sheet.max_row + 1
): # Skip header row (rows are 1-indexed; data starts at row 2)
row = sheet[row_num]
if ExcelUtils.is_row_empty(row):
continue
# Get content from columns 1-6, plus the prompt from column 9 (0-based 0-5 and 8)
serial_no_cell = row[0]
analysis_model_id_cell = row[1]
question_cell = row[2]
self_define_tags_cell = row[3]
knowledge_cell = row[4]
llm_output_cell = row[5]
prompt_cell = row[8]
# Build input model
input_model = BaseInputModel(
serial_no=int(
ExcelUtils.get_cell_value_as_string(serial_no_cell) or "0"
),
analysis_model_id=ExcelUtils.get_cell_value_as_string(
analysis_model_id_cell
),
question=ExcelUtils.get_cell_value_as_string(question_cell),
self_define_tags=ExcelUtils.get_cell_value_as_string(
self_define_tags_cell
),
llm_output=ExcelUtils.get_cell_value_as_string(llm_output_cell),
knowledge=ExcelUtils.get_cell_value_as_string(knowledge_cell),
prompt=ExcelUtils.get_cell_value_as_string(prompt_cell),
)
input_list.append(input_model)
input_sets.data_list = input_list
except Exception as e:
logger.error(f"parse excel error, path: {path}, errorMsg: {e}")
finally:
try:
if workbook is not None:
workbook.close()
except Exception as e:
logger.error(f"close workbook error, path: {path}, errorMsg: {e}")
return input_sets
def parse_standard_benchmark_sets(
self, standard_excel_path: str
) -> List[AnswerExecuteModel]:
df = pd.read_excel(standard_excel_path, sheet_name=0)
outputs: List[AnswerExecuteModel] = []
for _, row in df.iterrows():
try:
serial_no = int(row["编号"])
except Exception:
continue
question = row.get("用户问题")
analysis_model_id = row.get("数据集ID")
llm_output = (
None if pd.isna(row.get("标准答案SQL")) else str(row.get("标准答案SQL"))
)
order_by = True
if not pd.isna(row.get("是否排序")):
try:
order_by = bool(int(row.get("是否排序")))
except Exception:
order_by = True
std_result: Optional[List[Dict[str, List[str]]]] = None
if not pd.isna(row.get("标准结果")):
std_result_raw = str(row.get("标准结果"))
std_result = self._parse_multi_standard_result(std_result_raw)
strategy_config = DataCompareStrategyConfig(
strategy="CONTAIN_MATCH",
order_by=order_by,
standard_result=std_result if std_result is not None else None,
)
outputs.append(
AnswerExecuteModel(
serialNo=serial_no,
analysisModelId=analysis_model_id,
question=question,
llmOutput=llm_output,
executeResult=std_result,
strategyConfig=strategy_config,
)
)
return outputs
def _parse_multi_standard_result(
self, std_result_raw: str
) -> Optional[List[Dict[str, List[str]]]]:
"""
Parse multiple standard results from raw string data.
Handles multiple results separated by newlines and parses each line as a dict.
Args:
std_result_raw (str): Raw standard result string with multiple lines
Returns:
Optional[List[Dict[str, List[str]]]]: List of parsed dictionaries,
or None if parsing fails or no valid data
"""
try:
std_result_raw = std_result_raw.strip()
if not std_result_raw:
return None
# Multiple results are separated by newlines
result_lines = std_result_raw.split("\n")
result_list = []
for line in result_lines:
line = line.strip()
if line:
try:
result_list.append(json.loads(line))
except Exception as e:
logger.warning(
f"Failed to parse line as JSON: {line}, error: {e}"
)
continue
return result_list if result_list else None
except Exception as e:
logger.error(f"parse multiple standard results error: {e}")
return None
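
The standard-result column holds one JSON object per line, so parsing proceeds line by line, skipping blank or malformed lines. A self-contained sketch of that behavior (the sample payload is illustrative):

import json

def parse_multi_standard_result(std_result_raw: str):
    # One JSON object per line; blank or malformed lines are skipped, as above
    result_list = []
    for line in std_result_raw.strip().split("\n"):
        line = line.strip()
        if not line:
            continue
        try:
            result_list.append(json.loads(line))
        except json.JSONDecodeError:
            continue
    return result_list or None

raw = '{"city": ["Beijing", "Shanghai"]}\n{"city": ["Hangzhou"]}'
assert parse_multi_standard_result(raw) == [
    {"city": ["Beijing", "Shanghai"]},
    {"city": ["Hangzhou"]},
]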

View File

@@ -0,0 +1,572 @@
import asyncio
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from dbgpt.util import get_or_create_event_loop
from dbgpt_serve.evaluate.service.benchmark.file_parse_service import FileParseService
from dbgpt_serve.evaluate.service.benchmark.models import (
BaseInputModel,
BenchmarkDataSets, AnswerExecuteModel, DataCompareStrategyConfig,
)
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
FileLoadResult,
GoldenSqlListResult,
get_benchmark_manager,
)
logger = logging.getLogger(__name__)
TEXT_SQL_PROMPT = """
Given the following dataset, including field names and sampling information:
{Schema}
Containing the following knowledge:
{Knowledge}
Based on the above information, please generate the SQL for the following question:
{Query}
[Output Requirements]
* Field names in the generated SQL must use the actual field names from the table schema.
* Table names in the generated SQL must use the actual table names provided in the schema.
* Physical field names in the generated SQL must originate from the corresponding physical tables; generating fields that do not exist in the table is not allowed.
* Cartesian product calculations are not allowed in the generated SQL. This includes `CROSS JOIN`, `JOIN` operations missing `ON` or `USING` conditions, and multi-table joins without conditions set for all relationships between tables.
* The generated SQL must strictly adhere to {dialect} syntax. If it does not comply with this syntax, please regenerate it.
* Output only pure, executable SQL without any additional information.
[Example]
** Table 1 Information
*** Table Name: orders
*** DDL Statement:
CREATE TABLE "orders" (
"order_id" TEXT,
"customer_id" TEXT,
"item_type" TEXT,
"order_date" DATE
)
*** Field Information:
|Field Name|Field Type|Sample Data|
|:--:|:--:|:--:|
|order_id|text|CN-2025-2562,CN-2025-8623,CN-2025-6535|
|customer_id|text|52351,56263,71252|
|item_type|text|办公用品,设备,家具|
|order_date|date|2023-02-21,2024-03-30,2024-12-20|
User Question: 请帮我统计下各商品类型的订单数
Final Output SQL:
SELECT
item_type,
COUNT(*) as order_count
FROM
orders
GROUP BY
item_type;
"""
@dataclass
class BenchmarkDataItem:
"""benchmark data info item"""
question_id: int
db_id: str
question: str
sql: str
answer: List[Dict[str, List[str]]]
is_order: str
@staticmethod
def from_dict(data: dict) -> "BenchmarkDataItem":
return BenchmarkDataItem(
question_id=data.get("question_id", 0),
db_id=str(data.get("db_id", "")),
question=str(data.get("question", "")),
sql=str(data.get("SQL", "")),
answer=data.get("answer", []),
is_order=data.get("is_order", "0"),
)
@dataclass
class ColumnItem:
"""column info Item"""
column_id: int
column_name: str
column_type: str
sample_values: list
@staticmethod
def from_dict(data: dict) -> "ColumnItem":
"""从字典创建 ColumnItem 实例
Args:
data: 包含列信息的字典
Returns:
ColumnItem: 列信息实例
"""
return ColumnItem(
column_id=data.get("column_id", 0),
column_name=data.get("column_name", ""),
column_type=data.get("column_type", ""),
sample_values=data.get("sample_values", []),
)
@dataclass
class TableDDLItem:
"""Table DDL Info Item"""
table_id: int
table_name: str
columns: List[ColumnItem]
ddl: Optional[str] = None
@staticmethod
def from_dict(data: dict) -> "TableDDLItem":
"""从字典创建 TableDDLItem 实例
Args:
data: 包含表信息的字典
Returns:
TableDDLItem: 表信息实例
"""
columns_data = data.get("columns", [])
columns = [ColumnItem.from_dict(col) for col in columns_data]
return TableDDLItem(
table_id=data.get("table_id", 0),
table_name=data.get("table_name", ""),
columns=columns,
ddl=data.get("ddl"),
)
@dataclass
class TableDataItem:
"""Table Data Info Item"""
db_id: str
table_ddl: List[TableDDLItem]
@staticmethod
def from_dict(data: dict) -> "TableDataItem":
"""从字典创建 TableDataItem 实例
Args:
data: 包含数据库表信息的字典
Returns:
TableDataItem: 数据库表信息实例
"""
tables_data = data.get("tables", [])
table_ddl = [TableDDLItem.from_dict(table) for table in tables_data]
return TableDataItem(
db_id=str(data.get("db_id", "")),
table_ddl=table_ddl,
)
class SafeDict(dict):
def __missing__(self, key):
return '{' + key + '}'
class FalconFileParseService(FileParseService):
def __init__(self):
super().__init__()
self._dev_data_file = "dev_data/dev.json"
self._dev_table_ddl_file = "dev_data/tables.json"
self.benchmark_manager = get_benchmark_manager()
self._dev_data: Optional[FileLoadResult] = None
self._dev_table_ddl: Optional[FileLoadResult] = None
self._data_loaded = False
@staticmethod
def _format_answer_list(answer_list: List[Dict[str, List[str]]]) -> str:
"""格式化 answer 列表为字符串
Args:
answer_list: 答案列表,每个元素是字典,字典的值是字符串列表
Returns:
str: JSON 格式的字符串,如果列表为空则返回空字符串
"""
if not answer_list:
return ""
try:
import json
# Convert the answer list to JSON strings, one answer per line
return "\n".join(json.dumps(item, ensure_ascii=False) for item in answer_list)
except Exception as e:
logger.warning(f"Failed to format answer list: {e}")
return str(answer_list)
def _ensure_data_loaded(self):
if self._data_loaded:
return
logger.info("Loading benchmark data for the first time...")
try:
self._dev_data, self._dev_table_ddl = self._load_data_sync()
self._data_loaded = True
logger.info("Benchmark data loaded successfully")
except Exception as e:
logger.error(f"Failed to load benchmark data: {e}", exc_info=True)
raise RuntimeError(f"Failed to load benchmark data: {e}")
def parse_input_sets(self, path: str) -> BenchmarkDataSets:
"""
Parse input sets from GitHub repo
Args:
path: File URL path
Returns:
BenchmarkDataSets: Parsed data sets
"""
self._ensure_data_loaded()
try:
# 1. Parse the benchmark question data
benchmark_data_list = self._parse_benchmark_data(self._dev_data)
if not benchmark_data_list:
logger.error("Failed to parse benchmark data")
return BenchmarkDataSets(data_list=[])
# 2. Parse the table schemas
table_ddl_data_list = self._parse_table_ddl_data(self._dev_table_ddl)
if not table_ddl_data_list:
logger.error("Failed to parse talbe ddl data")
return BenchmarkDataSets(data_list=[])
table_ddl_data_map = {x.db_id: x.table_ddl for x in table_ddl_data_list}
# 3. Convert question data to a list of BaseInputModel and attach the standard answers
input_models = []
for idx, question_item in enumerate(benchmark_data_list, start=1):
input_model = BaseInputModel(
serial_no=question_item.question_id,
analysis_model_id=question_item.db_id,
question=question_item.question,
self_define_tags="",
knowledge="",
llm_output=self._format_answer_list(question_item.answer),
prompt=self.load_benchmark_prompt_template(question_item, table_ddl_data_map.get(question_item.db_id)),
)
input_models.append(input_model)
logger.info(f"Successfully parsed {len(input_models)} question items")
return BenchmarkDataSets(data_list=input_models)
except Exception as e:
logger.error(
f"load remote benchmark data error, error: {str(e)}",
exc_info=True,
)
return BenchmarkDataSets(data_list=[])
def parse_standard_benchmark_sets(
self, standard_excel_path: str
) -> List[AnswerExecuteModel]:
self._ensure_data_loaded()
outputs: List[AnswerExecuteModel] = []
# 1. Parse the benchmark question data
benchmark_data_list = self._parse_benchmark_data(self._dev_data)
if not benchmark_data_list:
logger.error("Failed to parse benchmark data")
return outputs
for idx, question_item in enumerate(benchmark_data_list, start=1):
serial_no = question_item.question_id
question = question_item.question
analysis_model_id = question_item.db_id
llm_output = question_item.sql
order_by = True
if question_item.is_order:
try:
order_by = bool(int(question_item.is_order))
except Exception:
order_by = True
std_result: Optional[List[Dict[str, List[str]]]] = None
if question_item.answer:
std_result = self._parse_multi_standard_result(question_item.answer)
strategy_config = DataCompareStrategyConfig(
strategy="CONTAIN_MATCH",
order_by=order_by,
standard_result=std_result if std_result is not None else None,
)
outputs.append(
AnswerExecuteModel(
serialNo=serial_no,
analysisModelId=analysis_model_id,
question=question,
llmOutput=llm_output,
executeResult=std_result,
strategyConfig=strategy_config,
)
)
return outputs
def _parse_benchmark_data(
self, benchmark_data: Optional[FileLoadResult]
) -> Optional[List[BenchmarkDataItem]]:
"""
Parse the benchmark question data
Args:
benchmark_data: Question file data loaded from GitHub
Returns:
List[BenchmarkDataItem]: List of question data, or None if parsing fails
"""
if not benchmark_data or not benchmark_data.rows:
return None
if benchmark_data.failed_count > 0:
logger.warning(
f"Question data has {benchmark_data.failed_count} failed rows"
)
benchmark_data_list = []
for row in benchmark_data.rows:
if not isinstance(row.data, dict):
logger.warning(
f"Row {row.line_no} data is not a dict: {type(row.data)}"
)
continue
benchmark_data_item = BenchmarkDataItem.from_dict(row.data)
if (
not benchmark_data_item.question_id
or not benchmark_data_item.question
or not benchmark_data_item.db_id
):
logger.warning(
f"Row {row.line_no} missing required fields: "
f"question_id={benchmark_data_item.question_id}, "
f"question={benchmark_data_item.question}, "
f"db_id={benchmark_data_item.db_id}"
)
continue
benchmark_data_list.append(benchmark_data_item)
if not benchmark_data_list:
logger.error("No valid benchmark data parsed")
return None
logger.info(
f"Successfully parsed {len(benchmark_data_list)} benchmark data"
f" from {len(benchmark_data.rows)} rows"
)
# TODO: temporarily take only the first 100 items
return benchmark_data_list[:100]
def _parse_table_ddl_data(
self, table_ddl_data: Optional[FileLoadResult]
) -> Optional[List[TableDataItem]]:
"""
Parse the table DDL data
Args:
table_ddl_data: Table DDL data loaded from GitHub
Returns:
List[TableDataItem]: List of table DDL data, or None if parsing fails
"""
if not table_ddl_data or not table_ddl_data.rows:
logger.warning("table ddl data is None")
return None
if table_ddl_data.failed_count > 0:
logger.warning(
f"table ddl data has {table_ddl_data.failed_count} failed items"
)
table_ddl_data_list = []
for row in table_ddl_data.rows:
if not isinstance(row.data, dict):
logger.warning(
f"Row {row.line_no} data is not a dict: {type(row.data)}"
)
continue
table_ddl_data_item = TableDataItem.from_dict(row.data)
if (
not table_ddl_data_item.db_id
or not table_ddl_data_item.table_ddl
):
logger.warning(
f"Row {row.line_no} missing required fields: "
f"db_id={table_ddl_data_item.db_id}, "
f"table_ddl={table_ddl_data_item.table_ddl}"
)
continue
table_ddl_data_list.append(table_ddl_data_item)
if not table_ddl_data_list:
return None
logger.info(
f"Successfully parsed {len(table_ddl_data_list)} table DDL data"
f" from {len(table_ddl_data.rows)} rows"
)
return table_ddl_data_list
async def _async_load_data(
self,
) -> Tuple[
Optional[FileLoadResult],
Optional[FileLoadResult],
]:
"""并发加载两个文件数据
使用 asyncio.gather 并发执行两个异步任务,提高加载效率
Returns:
Tuple: (dev_data, dev_table_ddl)
"""
dev_data, dev_table_ddl = await asyncio.gather(
self.benchmark_manager.load_file_from_github(self._dev_data_file),
self.benchmark_manager.load_file_from_github(self._dev_table_ddl_file),
)
return dev_data, dev_table_ddl
def _load_data_sync(
self,
) -> Tuple[
Optional[FileLoadResult],
Optional[FileLoadResult],
]:
"""在同步上下文中加载数据
智能检测当前事件循环状态:
- 如果事件循环正在运行,使用线程池在新线程中执行异步代码
- 如果没有运行中的事件循环,直接使用 run_until_complete
Returns:
Tuple: (dev_data, dev_table_ddl)
"""
try:
# Try to get the currently running event loop
asyncio.get_running_loop()
# Reaching here means a loop is running; execute via a thread pool
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, self._async_load_data())
return future.result()
except RuntimeError:
# No running loop; execute directly
loop = get_or_create_event_loop()
return loop.run_until_complete(self._async_load_data())
def _run_in_new_loop(
self,
) -> Tuple[
Optional[FileLoadResult],
Optional[FileLoadResult],
]:
"""
Run the async tasks in a new event loop
Returns:
Tuple: (dev_data, dev_table_ddl)
"""
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(self._async_load_data())
finally:
new_loop.close()
@staticmethod
def _build_table_schema_info(table_ddl_list: Optional[List[TableDDLItem]]) -> str:
"""构建表的 Schema 信息
Args:
table_ddl_list: 表 DDL 信息列表
Returns:
str: 格式化的表 Schema 信息字符串
"""
if not table_ddl_list:
return ""
schema_parts = []
for idx, table in enumerate(table_ddl_list, start=1):
# Table header info
schema_parts.append(f"** Table {idx} Information")
schema_parts.append(f"*** Table Name: {table.table_name}")
# If DDL is available, add it
if table.ddl:
schema_parts.append("*** DDL Statement:")
# DDL may span multiple lines; indent each line
ddl_lines = table.ddl.strip().split('\n')
for ddl_line in ddl_lines:
schema_parts.append(f" {ddl_line}")
# Column info header - use a clearer format
schema_parts.append("*** Field Information:")
schema_parts.append("|Field Name|Field Type|Sample Data|")
schema_parts.append("|:--:|:--:|:--:|")
# Add info for each column
for column in table.columns:
# Format sample values - keep within a reasonable length
if column.sample_values:
# Take the first few sample values (up to 3), comma-joined, to avoid overly long strings
sample_count = min(3, len(column.sample_values))
sample_str = ",".join(str(val) for val in column.sample_values[:sample_count])
# Truncate overly long sample strings
if len(sample_str) > 100:
sample_str = sample_str[:97] + "..."
else:
sample_str = "-"
schema_parts.append(f"|{column.column_name}|{column.column_type}|{sample_str}|")
# Separate tables with a blank line
if idx < len(table_ddl_list):
schema_parts.append("")
return "\n".join(schema_parts)
def _parse_multi_standard_result(
self, answer_list: List[Dict[str, List[str]]]
) -> Optional[List[Dict[str, List[str]]]]:
"""
Parse the standard answer results
Args:
answer_list: Answer list, already in the correct format
Returns:
Optional[List[Dict[str, List[str]]]]: The answer list, or None if empty
"""
try:
if not answer_list:
return None
return answer_list if answer_list else None
except Exception as e:
logger.error(f"parse standard results error: {e}")
return None
def load_benchmark_prompt_template(self, question_item: BenchmarkDataItem, table_ddl: List[TableDDLItem]) -> str:
"""
build benchmark prompt template
"""
schema = self._build_table_schema_info(table_ddl)
format_params = {
"Schema": schema,
"Knowledge": "",
"Query": question_item.question
}
# Use SafeDict with format_map for non-strict mode: missing variables don't raise errors
return TEXT_SQL_PROMPT.format_map(SafeDict(format_params))
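
TEXT_SQL_PROMPT also contains a {dialect} placeholder that load_benchmark_prompt_template does not supply; SafeDict plus format_map keeps unknown placeholders intact instead of raising KeyError. A small demonstration of that non-strict formatting:

class SafeDict(dict):
    def __missing__(self, key):
        # Return the placeholder itself so it survives formatting untouched
        return "{" + key + "}"

template = "Generate {dialect} SQL for: {Query}"
rendered = template.format_map(SafeDict({"Query": "count orders by type"}))
# {Query} is filled in; the unsupplied {dialect} remains for a later pass
assert rendered == "Generate {dialect} SQL for: count orders by type"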

View File

@@ -4,20 +4,16 @@ import logging
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import List, Optional, Any, Dict
import pandas as pd
from openpyxl import Workbook, load_workbook
from dbgpt.util.benchmarks.ExcelUtils import ExcelUtils
from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao
from .models import (
AnswerExecuteModel,
BaseInputModel,
BenchmarkDataSets,
BenchmarkExecuteConfig,
DataCompareStrategyConfig,
OutputType,
RoundAnswerConfirmModel,
)
@@ -27,8 +23,6 @@ logger = logging.getLogger(__name__)
class FileParseService(ABC):
def __init__(self):
self._benchmark_dao = BenchmarkResultDao()
# export column configuration file path
self._column_config_file_path = os.path.join(
os.path.dirname(__file__),
@@ -56,7 +50,6 @@ class FileParseService(ABC):
data.append(AnswerExecuteModel.from_dict(obj))
return data
@abstractmethod
def write_data_compare_result(
self,
path: str,
@@ -65,15 +58,100 @@ class FileParseService(ABC):
is_execute: bool,
llm_count: int,
):
"""Write compare results to File
"""Write compare results to an Excel file
Args:
path: Output file path
round_id: Round ID
confirm_models: List of answer confirm models
is_execute: Whether to execute the comparison
llm_count: LLM count
The output Excel file is named '<base>.xlsx' with a sheet named
'benchmark_compare_result'. If the file exists, rows are appended;
otherwise a new file with headers is created.
"""
try:
# Ensure output directory exists
output_dir = Path(path).parent
output_dir.mkdir(parents=True, exist_ok=True)
output_file = path
headers = [
"serialNo",
"analysisModelId",
"question",
"selfDefineTags",
"prompt",
"standardAnswerSql",
"standardAnswer",
"llmCode",
"llmOutput",
"executeResult",
"errorMsg",
"compareResult",
]
# Load or create workbook and sheet
if Path(output_file).exists():
workbook = load_workbook(str(output_file))
if "benchmark_compare_result" in workbook.sheetnames:
worksheet = workbook["benchmark_compare_result"]
else:
worksheet = workbook.create_sheet("benchmark_compare_result")
# Write headers if new sheet
for col_idx, header in enumerate(headers, 1):
worksheet.cell(row=1, column=col_idx, value=header)
else:
workbook = Workbook()
worksheet = workbook.active
worksheet.title = "benchmark_compare_result"
# Write headers
for col_idx, header in enumerate(headers, 1):
worksheet.cell(row=1, column=col_idx, value=header)
# Determine start row to append
start_row = worksheet.max_row + 1 if worksheet.max_row else 2
# Append rows
for idx, cm in enumerate(confirm_models):
row_data = [
cm.serialNo,
cm.analysisModelId,
cm.question,
cm.selfDefineTags,
cm.prompt,
cm.standardAnswerSql,
self._format_set_result(cm.strategyConfig.standard_result)
if cm.strategyConfig is not None
else "",
cm.llmCode,
cm.llmOutput,
json.dumps(cm.executeResult, ensure_ascii=False)
if cm.executeResult is not None
else "",
cm.errorMsg,
cm.compareResult.value if cm.compareResult else None,
]
for col_idx, value in enumerate(row_data, 1):
worksheet.cell(row=start_row + idx, column=col_idx, value=value)
# Autosize columns (simple strategy)
for column in worksheet.columns:
max_length = 0
column_letter = column[0].column_letter
for cell in column:
try:
if cell.value and len(str(cell.value)) > max_length:
max_length = len(str(cell.value))
except Exception:
pass
adjusted_width = min(max(max_length + 2, 10), 80)
worksheet.column_dimensions[column_letter].width = adjusted_width
workbook.save(str(output_file))
workbook.close()
logger.info(
f"[write_data_compare_result] compare written to Excel: {output_file}"
)
except Exception as e:
logger.error(
f"[write_data_compare_result] write excel error for path={path}: {e}"
)
def summary_and_write_multi_round_benchmark_result(
self, output_path: str, round_id: int
@@ -163,7 +241,6 @@ class FileParseService(ABC):
"""
pass
@abstractmethod
def write_multi_round_benchmark_result(
self,
output_file_path: str,
@@ -186,149 +263,6 @@ class FileParseService(ABC):
start_index: Starting index (batch start row index)
offset: Offset(file rows offset)
"""
class ExcelFileParseService(FileParseService):
def parse_input_sets(self, path: str) -> BenchmarkDataSets:
"""
Parse input sets from excel file
Args:
path: File location path
Returns:
BenchmarkDataSets: Parsed data sets
"""
input_stream = self.get_input_stream(path)
if input_stream is None:
raise RuntimeError(f"file not found! path: {path}")
# Parse excel file to get data sets
input_sets = BenchmarkDataSets()
workbook = None
try:
workbook = load_workbook(input_stream, data_only=True)
input_list = []
# Get the first worksheet
sheet = workbook.worksheets[0]
for row_num in range(
2, sheet.max_row + 1
): # Skip header row (rows are 1-indexed; data starts at row 2)
row = sheet[row_num]
if ExcelUtils.is_row_empty(row):
continue
# Get content from columns 1-6, plus the prompt from column 9 (0-based 0-5 and 8)
serial_no_cell = row[0]
analysis_model_id_cell = row[1]
question_cell = row[2]
self_define_tags_cell = row[3]
knowledge_cell = row[4]
llm_output_cell = row[5]
prompt_cell = row[8]
# Build input model
input_model = BaseInputModel(
serial_no=int(
ExcelUtils.get_cell_value_as_string(serial_no_cell) or "0"
),
analysis_model_id=ExcelUtils.get_cell_value_as_string(
analysis_model_id_cell
),
question=ExcelUtils.get_cell_value_as_string(question_cell),
self_define_tags=ExcelUtils.get_cell_value_as_string(
self_define_tags_cell
),
llm_output=ExcelUtils.get_cell_value_as_string(llm_output_cell),
knowledge=ExcelUtils.get_cell_value_as_string(knowledge_cell),
prompt=ExcelUtils.get_cell_value_as_string(prompt_cell),
)
input_list.append(input_model)
input_sets.data_list = input_list
except Exception as e:
logger.error(f"parse excel error, path: {path}, errorMsg: {e}")
finally:
try:
if workbook is not None:
workbook.close()
except Exception as e:
logger.error(f"close workbook error, path: {path}, errorMsg: {e}")
return input_sets
def parse_standard_benchmark_sets(
self, standard_excel_path: str
) -> List[AnswerExecuteModel]:
df = pd.read_excel(standard_excel_path, sheet_name=0)
outputs: List[AnswerExecuteModel] = []
for _, row in df.iterrows():
try:
serial_no = int(row["编号"])
except Exception:
continue
question = row.get("用户问题")
analysis_model_id = row.get("数据集ID")
llm_output = (
None if pd.isna(row.get("标准答案SQL")) else str(row.get("标准答案SQL"))
)
order_by = True
if not pd.isna(row.get("是否排序")):
try:
order_by = bool(int(row.get("是否排序")))
except Exception:
order_by = True
std_result: Optional[List[Dict[str, List[str]]]] = None
if not pd.isna(row.get("标准结果")):
std_result_raw = str(row.get("标准结果"))
std_result = self._parse_multi_standard_result(std_result_raw)
strategy_config = DataCompareStrategyConfig(
strategy="CONTAIN_MATCH",
order_by=order_by,
standard_result=std_result if std_result is not None else None,
)
outputs.append(
AnswerExecuteModel(
serialNo=serial_no,
analysisModelId=analysis_model_id,
question=question,
llmOutput=llm_output,
executeResult=std_result,
strategyConfig=strategy_config,
)
)
return outputs
def write_multi_round_benchmark_result(
self,
output_file_path: str,
round_id: int,
config: BenchmarkExecuteConfig,
inputs: List[BaseInputModel],
outputs: List[OutputType],
start_index: int,
offset: int,
) -> bool:
"""
Write the benchmark Result to Excel File With Multi Round
Args:
output_file_path: Output file path
round_id: Round ID
config: Benchmark configuration
inputs: List of input data
outputs: List of output data
start_index: Starting index (batch start row index)
offset: Offset(file rows offset)
Returns:
bool: Returns True if write is successful, False otherwise
"""
try:
# Ensure the output directory exists
output_dir = Path(output_file_path).parent
@@ -436,109 +370,6 @@ class ExcelFileParseService(FileParseService):
logger.error(f"write excel file error: {e}", exc_info=True)
return False
def write_data_compare_result(
self,
path: str,
round_id: int,
confirm_models: List[RoundAnswerConfirmModel],
is_execute: bool,
llm_count: int,
):
"""Write compare results to an Excel file
The output Excel file is named '<base>.xlsx' with a sheet named
'benchmark_compare_result'. If the file exists, rows are appended;
otherwise a new file with headers is created.
"""
try:
# Ensure output directory exists
output_dir = Path(path).parent
output_dir.mkdir(parents=True, exist_ok=True)
output_file = path
headers = [
"serialNo",
"analysisModelId",
"question",
"selfDefineTags",
"prompt",
"standardAnswerSql",
"standardAnswer",
"llmCode",
"llmOutput",
"executeResult",
"errorMsg",
"compareResult",
]
# Load or create workbook and sheet
if Path(output_file).exists():
workbook = load_workbook(str(output_file))
if "benchmark_compare_result" in workbook.sheetnames:
worksheet = workbook["benchmark_compare_result"]
else:
worksheet = workbook.create_sheet("benchmark_compare_result")
# Write headers if new sheet
for col_idx, header in enumerate(headers, 1):
worksheet.cell(row=1, column=col_idx, value=header)
else:
workbook = Workbook()
worksheet = workbook.active
worksheet.title = "benchmark_compare_result"
# Write headers
for col_idx, header in enumerate(headers, 1):
worksheet.cell(row=1, column=col_idx, value=header)
# Determine start row to append
start_row = worksheet.max_row + 1 if worksheet.max_row else 2
# Append rows
for idx, cm in enumerate(confirm_models):
row_data = [
cm.serialNo,
cm.analysisModelId,
cm.question,
cm.selfDefineTags,
cm.prompt,
cm.standardAnswerSql,
self._format_set_result(cm.strategyConfig.standard_result)
if cm.strategyConfig is not None
else "",
cm.llmCode,
cm.llmOutput,
json.dumps(cm.executeResult, ensure_ascii=False)
if cm.executeResult is not None
else "",
cm.errorMsg,
cm.compareResult.value if cm.compareResult else None,
]
for col_idx, value in enumerate(row_data, 1):
worksheet.cell(row=start_row + idx, column=col_idx, value=value)
# Autosize columns (simple strategy)
for column in worksheet.columns:
max_length = 0
column_letter = column[0].column_letter
for cell in column:
try:
if cell.value and len(str(cell.value)) > max_length:
max_length = len(str(cell.value))
except Exception:
pass
adjusted_width = min(max(max_length + 2, 10), 80)
worksheet.column_dimensions[column_letter].width = adjusted_width
workbook.save(str(output_file))
workbook.close()
logger.info(
f"[write_data_compare_result] compare written to Excel: {output_file}"
)
except Exception as e:
logger.error(
f"[write_data_compare_result] write excel error for path={path}: {e}"
)
def _get_value_by_source_type(
self,
field: str,
@@ -636,46 +467,6 @@ class ExcelFileParseService(FileParseService):
else:
return str(value) if value is not None else ""
def _parse_multi_standard_result(
self, std_result_raw: str
) -> Optional[List[Dict[str, List[str]]]]:
"""
Parse multiple standard results from raw string data.
Handles multiple results separated by newlines and parses each line as a dict.
Args:
std_result_raw (str): Raw standard result string with multiple lines
Returns:
Optional[List[Dict[str, List[str]]]]: List of parsed dictionaries,
or None if parsing fails or no valid data
"""
try:
std_result_raw = std_result_raw.strip()
if not std_result_raw:
return None
# Multiple results are separated by newlines
result_lines = std_result_raw.split("\n")
result_list = []
for line in result_lines:
line = line.strip()
if line:
try:
result_list.append(json.loads(line))
except Exception as e:
logger.warning(
f"Failed to parse line as JSON: {line}, error: {e}"
)
continue
return result_list if result_list else None
except Exception as e:
logger.error(f"parse multiple standard results error: {e}")
return None
def _format_set_result(
self, sql_result: List[Dict[str, List[str]]]
) -> Optional[str]:
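
write_data_compare_result, now implemented on the base class, creates the result workbook on first use and appends rows on later rounds. The load-or-create pattern it relies on, reduced to a runnable sketch (file path is illustrative; sheet name and headers follow the code above):

from pathlib import Path
from openpyxl import Workbook, load_workbook

def open_result_sheet(path: str, headers: list):
    # Reuse an existing workbook/sheet when present; otherwise create both
    if Path(path).exists():
        wb = load_workbook(path)
        if "benchmark_compare_result" in wb.sheetnames:
            ws = wb["benchmark_compare_result"]
        else:
            ws = wb.create_sheet("benchmark_compare_result")
            ws.append(headers)  # headers only on a freshly created sheet
    else:
        wb = Workbook()
        ws = wb.active
        ws.title = "benchmark_compare_result"
        ws.append(headers)
    return wb, ws

wb, ws = open_result_sheet("/tmp/compare_result.xlsx", ["serialNo", "question"])
ws.append([1, "demo question"])  # appended after the current last row
wb.save("/tmp/compare_result.xlsx")
wb.close()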

View File

@@ -4,7 +4,6 @@ import logging
import os
from typing import Dict, List, Optional, Union
from dbgpt.util.benchmarks import StorageUtil
from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
BENCHMARK_DEFAULT_DB_SCHEMA,
@@ -12,6 +11,8 @@ from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
)
from .data_compare_service import DataCompareService
from .ext.excel_file_parse import ExcelFileParseService
from .ext.falcon_file_parse import FalconFileParseService
from .file_parse_service import FileParseService
from .models import (
AnswerExecuteModel,
@@ -33,14 +34,28 @@ logger = logging.getLogger(__name__)
class UserInputExecuteService:
def __init__(
self, file_service: FileParseService, compare_service: DataCompareService
self, compare_service: DataCompareService, file_type: Optional[FileParseTypeEnum] = None
):
self.file_service = file_service
self.compare_service = compare_service
self.file_service = self.file_service(file_type)
self.dao = BenchmarkResultDao()
# sql query timeout in seconds
self.query_timeout = float(os.getenv("BENCHMARK_SQL_TIMEOUT", 360.0))
def file_service(self, file_type: FileParseTypeEnum) -> FileParseService:
"""
Get file service instance based on file type.
Returns:
FileParseService: File service instance
"""
if file_type == FileParseTypeEnum.GITHUB:
return FalconFileParseService()
elif file_type == FileParseTypeEnum.EXCEL:
return ExcelFileParseService()
raise NotImplementedError(f"fileParseType: {file_type} is not implemented yet")
def read_input_file(
self, input_file_path: str
) -> Union[List[BaseInputModel], None]:
@@ -53,15 +68,10 @@ class UserInputExecuteService:
Returns:
List[BaseInputModel]: Input data list
"""
file_parse_type: FileParseTypeEnum = StorageUtil.get_file_parse_type(
input_sets: BenchmarkDataSets = self.file_service.parse_input_sets(
input_file_path
)
if file_parse_type == FileParseTypeEnum.EXCEL:
input_sets: BenchmarkDataSets = self.file_service.parse_input_sets(
input_file_path
)
return input_sets.data_list
return None
return input_sets.data_list
def post_dispatch(
self,
@@ -225,14 +235,13 @@ class UserInputExecuteService:
)
results = json.loads(summary_json) if summary_json else []
dao = BenchmarkResultDao()
for item in results:
llm_code = item.get("llmCode")
right = int(item.get("right", 0))
wrong = int(item.get("wrong", 0))
failed = int(item.get("failed", 0))
exception = int(item.get("exception", 0))
dao.upsert_summary(
self.dao.upsert_summary(
round_id,
location,
llm_code,
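
With the constructor change, callers select a parser indirectly through the enum. A stripped-down sketch of the factory dispatch (service classes stubbed out):

from enum import Enum

class FileParseTypeEnum(Enum):  # assumed members, matching the code above
    EXCEL = "excel"
    GITHUB = "github"

class ExcelFileParseService:
    pass

class FalconFileParseService:
    pass

def file_service(file_type: FileParseTypeEnum):
    # Mirrors UserInputExecuteService.file_service()
    if file_type == FileParseTypeEnum.GITHUB:
        return FalconFileParseService()
    elif file_type == FileParseTypeEnum.EXCEL:
        return ExcelFileParseService()
    raise NotImplementedError(f"fileParseType: {file_type} is not implemented yet")

assert isinstance(file_service(FileParseTypeEnum.GITHUB), FalconFileParseService)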

View File

@@ -4,15 +4,17 @@ import hashlib
import json
import logging
import os
import re
import shutil
import tempfile
import threading
import time
import uuid
import zipfile
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeoutError
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, cast
from typing import Any, Dict, List, Optional, Tuple, cast, Literal
import aiohttp
from sqlalchemy import text
@@ -24,6 +26,90 @@ from dbgpt_ext.datasource.rdbms.conn_sqlite import SQLiteConnector
logger = logging.getLogger(__name__)
# ---- Unified model result definitions for load_file_from_github ----
class FailureDetail(BaseModel):
line_no: int
error: str
line: str
class Row(BaseModel):
line_no: int
data: Any
class FileLoadResult(BaseModel):
type: Literal["jsonl", "json", "text"]
file_path: str
file_name: str
encoding: Optional[str] = None
rows: List[Row]
count: int
failed_count: int
failures: List[FailureDetail] = []
class SqlFileItem(BaseModel):
"""Represents a single SQL file with its ID and content"""
sql_id: str
sql_content: str
file_path: str
file_name: str
encoding: Optional[str] = None
class GoldenSqlListResult(BaseModel):
"""Result object for golden SQL list loading
Provides efficient lookup by SQL ID with dict-like interface.
"""
sql_items: Dict[str, SqlFileItem]
total_count: int
failed_count: int
def get_by_id(self, sql_id: str) -> Optional[SqlFileItem]:
"""Get SQL item by ID
Args:
sql_id: The SQL file ID (filename prefix without extension)
Returns:
SqlFileItem if found, None otherwise
"""
return self.sql_items.get(sql_id)
def get_sql_content(self, sql_id: str) -> Optional[str]:
"""Get SQL content by ID
Args:
sql_id: The SQL file ID (filename prefix without extension)
Returns:
SQL content string if found, None otherwise
"""
item = self.sql_items.get(sql_id)
return item.sql_content if item else None
def list_all_ids(self) -> List[str]:
"""Get list of all SQL IDs
Returns:
List of SQL IDs sorted alphabetically
"""
return sorted(self.sql_items.keys())
def __len__(self) -> int:
"""Return number of successfully loaded SQL files"""
return len(self.sql_items)
def __contains__(self, sql_id: str) -> bool:
"""Check if SQL ID exists"""
return sql_id in self.sql_items
def __iter__(self):
"""Iterate over SQL items"""
return iter(self.sql_items.values())
BENCHMARK_DEFAULT_DB_SCHEMA = "ant_icube_dev."
@@ -36,12 +122,10 @@ class BenchmarkDataConfig(BaseModel):
db_path: str = os.path.join(
BENCHMARK_DATA_ROOT_PATH, f"{BENCHMARK_DEFAULT_DB_SCHEMA}db"
)
table_mapping_file: str = os.path.join(
BENCHMARK_DATA_ROOT_PATH, "table_mapping.json"
)
table_mapping_file: Optional[str] = None
cache_expiry_days: int = 1
repo_url: str = "https://github.com/eosphoros-ai/Falcon"
data_dir: str = "data/source"
data_dirs: List[str] = ["dev_data/dev_databases", "test_data/dev_databases"]
class BenchmarkDataManager(BaseComponent):
@@ -56,7 +140,6 @@ class BenchmarkDataManager(BaseComponent):
self._config = config or BenchmarkDataConfig()
self._http_session: Optional[aiohttp.ClientSession] = None
self._connector: Optional[SQLiteConnector] = None
self._table_mappings = self._load_mappings()
self._lock = asyncio.Lock()
self.temp_dir: Optional[str] = None
@@ -73,9 +156,7 @@ class BenchmarkDataManager(BaseComponent):
async def async_before_stop(self):
try:
logger.info("BenchmarkDataManager: closing resources before stop...")
await self.close()
logger.info("BenchmarkDataManager: close done.")
except Exception as e:
logger.warning(f"BenchmarkDataManager: close failed: {e}")
@@ -120,7 +201,6 @@ class BenchmarkDataManager(BaseComponent):
async def load_data(self):
logger.info("BenchmarkDataManager: start load_data.")
try:
if not self._config.repo_url:
logger.info("BenchmarkDataManager: repo_url not set, skip auto load.")
@@ -132,69 +212,16 @@ class BenchmarkDataManager(BaseComponent):
logger.info(
f"BenchmarkDataManager: auto loading repo {self._config.repo_url} "
f"dir={self._config.data_dir}"
f"dirs={self._config.data_dirs}"
)
await get_benchmark_manager(self.system_app).load_from_github(
repo_url=self._config.repo_url, data_dir=self._config.data_dir
repo_url=self._config.repo_url, data_dirs=self._config.data_dirs
)
self._startup_loaded = True
logger.info("BenchmarkDataManager: auto load finished.")
except Exception as e:
logger.error(f"BenchmarkDataManager: auto load failed: {e}")
def _sanitize_column_name(self, name: str) -> str:
if name is None:
return ""
name = str(name).strip().strip('"').strip("'")
invalid_chars = [
"-",
" ",
".",
",",
";",
":",
"!",
"?",
"'",
'"',
"(",
")",
"[",
"]",
"{",
"}",
"\t",
"\r",
"\n",
"\x00",
]
while name and name[-1] in invalid_chars:
name = name[:-1]
for ch in invalid_chars:
if ch in name:
name = name.replace(ch, "_")
while "__" in name:
name = name.replace("__", "_")
if name and not (name[0].isalpha() or name[0] == "_"):
name = "_" + name
return name.lower()
def _sanitize_and_dedup_headers(self, headers: List[str]) -> List[str]:
sanitized: List[str] = []
used: set = set()
for idx, h in enumerate(headers):
name = self._sanitize_column_name(h)
if not name:
name = f"col_{idx}"
base = name
k = 2
while name in used or not name:
name = f"{base}_{k}"
k += 1
used.add(name)
sanitized.append(name)
return sanitized
# ==========================================================
# Generic query (blocking implementation, invoked in a thread pool; supports timeout and cancellation)
@@ -321,7 +348,9 @@ class BenchmarkDataManager(BaseComponent):
return [dict(zip(cols, row)) for row in rows]
async def load_from_github(
self, repo_url: str, data_dir: str = "data/source"
self,
repo_url: str,
data_dirs: List[str] = ["dev_data/dev_databases", "test_data/dev_databases"],
) -> Dict:
"""Main method to load data from GitHub repository"""
try:
@@ -330,14 +359,26 @@ class BenchmarkDataManager(BaseComponent):
# 1. Download or use cached repository
repo_dir = await self._download_repo_contents(repo_url)
# 2. Find all CSV files recursively
csv_files = self._discover_csv_files(repo_dir, data_dir)
if not csv_files:
raise ValueError("No CSV files found")
logger.info(f"Found {len(csv_files)} CSV files")
# 2. Find all SQLite files recursively in the specified data_dirs
all_sqlite_files = []
for d_dir in data_dirs:
try:
files = self._discover_sqlite_files(repo_dir, d_dir)
logger.info(f"Found {len(files)} SQLite files in {d_dir}")
all_sqlite_files.extend(files)
except ValueError as ve:
# If a directory is missing, log a warning but don't abort the whole process
logger.warning(f"Skip directory {d_dir}: {ve}")
# 3. Import to SQLite
result = await self._import_to_database(csv_files)
if not all_sqlite_files:
raise ValueError(
f"No SQLite files found in any of the directories: {data_dirs}"
)
logger.info(f"Total SQLite files to merge: {len(all_sqlite_files)}")
# 3. Merge all SQLite files into the main database
result = await self._merge_sqlite_databases(all_sqlite_files)
return result
except Exception as e:
@@ -346,6 +387,280 @@ class BenchmarkDataManager(BaseComponent):
finally:
self._cleanup_temp_dir()
async def load_file_from_github(self, file_name: Optional[str] = None
) -> Optional[FileLoadResult]:
"""Download and read a specified file from a GitHub repository.
Supported file types: .json / .jsonl
`file_name` can be a relative path within the repository or a plain filename (will be searched recursively).
Unified return structure (FileLoadResult):
- type: "json" | "jsonl"
- file_path, file_name, encoding
- rows: List[{line_no:int, data:Any}] where data is parsed JSON object
- count: total number of rows
- failed_count: number of failed lines (non-zero for jsonl or malformed json)
- failures: details for failed lines
For JSON files:
- If the file contains a JSON array, each element becomes a Row
- If the file contains a single JSON object, it becomes one Row
- The structure is flexible and doesn't depend on specific keys
"""
try:
if not file_name or not str(file_name).strip():
return None
# Download or use cached repository
repo_dir = await self._download_repo_contents(self._config.repo_url)
# Allowed file extensions
allowed_exts = {".jsonl", ".json"}
# Pre-check extension of `file_name` (if provided), otherwise filter by allowed list later
_, requested_ext = os.path.splitext(str(file_name).lower())
if requested_ext and requested_ext not in allowed_exts:
raise ValueError(f"Unsupported file type: {requested_ext}")
# Handle both relative path and plain filename cases
normalized = str(file_name).strip().lstrip("/").replace("\\", os.sep)
candidate_paths: List[str] = []
# Prefer direct path resolution using the relative path
direct_path = os.path.join(repo_dir, normalized)
if os.path.isfile(direct_path):
ext = os.path.splitext(direct_path.lower())[1]
if not requested_ext:
if ext in allowed_exts:
candidate_paths.append(direct_path)
elif ext == requested_ext:
candidate_paths.append(direct_path)
# If not found, recursively search by filename match
if not candidate_paths:
target_name = os.path.basename(normalized)
for root, _, files in os.walk(repo_dir):
for f in files:
if f == target_name:
full = os.path.join(root, f)
ext = os.path.splitext(f.lower())[1]
if not requested_ext:
if ext in allowed_exts:
candidate_paths.append(full)
elif ext == requested_ext:
candidate_paths.append(full)
if not candidate_paths:
raise FileNotFoundError(f"File not found: {file_name}")
# Choose a stable candidate (sorted by path length and lexicographical order)
chosen = sorted(candidate_paths, key=lambda p: (len(p), p))[0]
chosen_ext = os.path.splitext(chosen.lower())[1]
# Build repository-relative path for the file (avoid returning temp local path)
rel_path = os.path.relpath(chosen, repo_dir)
rel_path_posix = rel_path.replace(os.sep, "/")
# Try multiple encodings
encodings = ["utf-8", "iso-8859-1"]
# Handle .json files (array or single object)
if chosen_ext == ".json":
return await self._parse_json_file(
chosen, rel_path_posix, encodings
)
# Handle .jsonl files (line-delimited JSON)
elif chosen_ext == ".jsonl":
return await self._parse_jsonl_file(
chosen, rel_path_posix, encodings
)
else:
raise ValueError(f"Unsupported file extension: {chosen_ext}")
except Exception as e:
logger.error(f"Falcon repository Import failed: {str(e)}")
raise RuntimeError(f"Falcon repository file data loading failed: {e}") from e
finally:
self._cleanup_temp_dir()
async def _parse_json_file(
self, file_path: str, rel_path_posix: str, encodings: List[str]
) -> FileLoadResult:
"""Parse a JSON file (array or single object).
Args:
file_path: Absolute path to the JSON file
rel_path_posix: Repository-relative path in POSIX format
encodings: List of encodings to try
Returns:
FileLoadResult with parsed data
"""
rows: List[Row] = []
failures: List[FailureDetail] = []
used_encoding: Optional[str] = None
# Try reading with different encodings
for enc in encodings:
try:
with open(file_path, "r", encoding=enc) as f:
content = f.read()
try:
data = json.loads(content)
# Handle JSON array
if isinstance(data, list):
for idx, item in enumerate(data, start=1):
rows.append(Row(line_no=idx, data=item))
# Handle single JSON object
elif isinstance(data, dict):
rows.append(Row(line_no=1, data=data))
else:
# Handle primitive types (string, number, etc.)
rows.append(Row(line_no=1, data=data))
used_encoding = enc
break
except json.JSONDecodeError as e:
failures.append(
FailureDetail(
line_no=1,
error=f"JSON decode error: {str(e)}",
line=content[:200],
)
)
used_encoding = enc
break
except UnicodeDecodeError:
continue
except Exception as e:
logger.warning(f"Read json with encoding {enc} failed: {e}")
continue
# Fallback: read as bytes and decode with ASCII ignoring errors
if used_encoding is None:
try:
with open(file_path, "rb") as f:
content = f.read().decode("ascii", errors="ignore")
try:
data = json.loads(content)
if isinstance(data, list):
for idx, item in enumerate(data, start=1):
rows.append(Row(line_no=idx, data=item))
elif isinstance(data, dict):
rows.append(Row(line_no=1, data=data))
else:
rows.append(Row(line_no=1, data=data))
except json.JSONDecodeError as e:
failures.append(
FailureDetail(
line_no=1,
error=f"JSON decode error: {str(e)}",
line=content[:200],
)
)
used_encoding = "ascii-ignore"
except Exception as e:
raise ValueError(f"Failed to read json file: {e}")
return FileLoadResult(
type="json",
file_path=rel_path_posix,
file_name=os.path.basename(file_path),
encoding=used_encoding,
rows=rows,
count=len(rows) + len(failures),
failed_count=len(failures),
failures=failures,
)
async def _parse_jsonl_file(
self, file_path: str, rel_path_posix: str, encodings: List[str]
) -> FileLoadResult:
"""Parse a JSONL file (line-delimited JSON).
Args:
file_path: Absolute path to the JSONL file
rel_path_posix: Repository-relative path in POSIX format
encodings: List of encodings to try
Returns:
FileLoadResult with parsed data
"""
rows: List[Row] = []
failures: List[FailureDetail] = []
used_encoding: Optional[str] = None
# Prefer reading in text mode with multiple encodings
for enc in encodings:
try:
with open(file_path, "r", encoding=enc) as f:
for idx, line in enumerate(f, start=1):
s = line.strip()
if not s:
continue
try:
obj = json.loads(s)
rows.append(Row(line_no=idx, data=obj))
except Exception as e:
failures.append(
FailureDetail(
line_no=idx,
error=str(e),
line=s[:200],
)
)
used_encoding = enc
break
except UnicodeDecodeError:
continue
except Exception as e:
logger.warning(f"Read jsonl with encoding {enc} failed: {e}")
continue
# Fallback: read as bytes and decode with ASCII ignoring errors
if used_encoding is None:
try:
with open(file_path, "rb") as f:
for idx, raw_line in enumerate(f, start=1):
s = raw_line.decode("ascii", errors="ignore").strip()
if not s:
continue
try:
obj = json.loads(s)
rows.append(Row(line_no=idx, data=obj))
except Exception as e:
failures.append(
FailureDetail(
line_no=idx,
error=str(e),
line=s[:200],
)
)
used_encoding = "ascii-ignore"
except Exception as e:
raise ValueError(f"Failed to read jsonl file: {e}")
return FileLoadResult(
type="jsonl",
file_path=rel_path_posix,
file_name=os.path.basename(file_path),
encoding=used_encoding,
rows=rows,
count=(len(rows) + len(failures)),
failed_count=len(failures),
failures=failures,
)
async def get_table_info(self) -> Dict:
"""Get metadata about all tables"""
await self.init_connector()
@@ -445,7 +760,7 @@ class BenchmarkDataManager(BaseComponent):
return (mapped_name or "").lower()
async def _download_repo_contents(self, repo_url: str) -> str:
"""Download repository with caching"""
"""Download repository with caching, supporting branch URLs"""
cache_path = self._get_cache_path(repo_url)
# Use cache if valid
@@ -455,21 +770,45 @@ class BenchmarkDataManager(BaseComponent):
# Download fresh copy
self.temp_dir = tempfile.mkdtemp()
zip_url = (
repo_url.replace("github.com", "api.github.com/repos") + "/zipball/main"
)
# Simple parsing for github.com URLs
github_pattern = r"github\.com/([^/]+)/([^/]+)(?:/tree/(.+))?"
match = re.search(github_pattern, repo_url)
if match:
owner, repo, branch = match.groups()
branch = branch or "main" # Default to main if no tree/branch specified
zip_url = f"https://api.github.com/repos/{owner}/{repo}/zipball/{branch}"
else:
# Fallback for generic structure or direct zip links
if repo_url.endswith(".zip"):
zip_url = repo_url
else:
# Default fallback behavior from original code
zip_url = (
repo_url.replace("github.com", "api.github.com/repos")
+ "/zipball/main"
)
logger.info(f"Downloading from GitHub repo: {zip_url}")
try:
if self._http_session is None:
self._http_session = aiohttp.ClientSession()
async with self._http_session.get(zip_url) as response:
response.raise_for_status()
headers = {"Accept": "application/vnd.github.v3+json"}
async with self._http_session.get(zip_url, headers=headers) as response:
if response.status != 200:
text_resp = await response.text()
raise RuntimeError(
f"GitHub API Error {response.status}: {text_resp}"
)
zip_path = os.path.join(self.temp_dir, "repo.zip")
with open(zip_path, "wb") as f:
while True:
chunk = await response.content.read(1024)
chunk = await response.content.read(1024 * 1024) # 1MB chunks
if not chunk:
break
f.write(chunk)
@@ -479,7 +818,6 @@ class BenchmarkDataManager(BaseComponent):
logger.info(f"Saved repository to cache: {cache_path}")
return self._extract_zip(zip_path)
except Exception as e:
self._cleanup_temp_dir()
raise RuntimeError(f"Failed to download repository: {str(e)}") from e
@@ -515,252 +853,112 @@ class BenchmarkDataManager(BaseComponent):
raise ValueError("No valid directory found after extraction")
return os.path.join(self.temp_dir, extracted_dirs[0])
def _discover_csv_files(self, base_dir: str, search_dir: str) -> List[Dict]:
"""Find all CSV files recursively"""
def _discover_sqlite_files(self, base_dir: str, search_dir: str) -> List[str]:
"""Find all SQLite files recursively in the search directory"""
full_search_dir = os.path.join(base_dir, search_dir) if search_dir else base_dir
if not os.path.exists(full_search_dir):
raise ValueError(f"Directory not found: {full_search_dir}")
csv_files = []
sqlite_files = []
for root, _, files in os.walk(full_search_dir):
for file in files:
if file.lower().endswith(".csv"):
rel_path = os.path.relpath(root, start=base_dir)
csv_files.append(
{
"full_path": os.path.join(root, file),
"rel_path": rel_path,
"file_name": file,
}
)
return csv_files
if file.lower().endswith(".sqlite"):
full_path = os.path.join(root, file)
sqlite_files.append(full_path)
return sqlite_files
async def _import_to_database(self, csv_files: List[Dict]) -> Dict:
"""Import CSV data to SQLite"""
async def _merge_sqlite_databases(self, sqlite_files: List[str]) -> Dict:
"""Merge multiple SQLite files into the main database"""
await self.init_connector()
assert self._connector is not None
results = {
"total_files": len(csv_files),
"successful": 0,
"failed": 0,
"tables_created": [],
}
def _process_one_file(file_info: Dict) -> Tuple[bool, Optional[str]]:
table_name = ""
try:
path_parts = [p for p in file_info["rel_path"].split(os.sep) if p]
table_name = "_".join(path_parts + [Path(file_info["file_name"]).stem])
table_name = self._sanitize_table_name(table_name)
with self._connector.session_scope() as session:
session.execute(text(f'DROP TABLE IF EXISTS "{table_name}"'))
session.commit()
encodings = ["utf-8-sig", "utf-8", "latin-1", "iso-8859-1", "cp1252"]
for encoding in encodings:
try:
with open(file_info["full_path"], "r", encoding=encoding) as f:
content = f.read()
if not content.strip():
raise ValueError("File is empty")
content = content.replace("\r\n", "\n").replace("\r", "\n")
lines = [line for line in content.split("\n") if line.strip()]
if not lines:
raise ValueError("No data after normalization")
header_line = lines[0]
data_line = lines[1] if len(lines) > 1 else ""
try:
sample_for_sniff = "\n".join(lines[:10])
sniffer = csv.Sniffer()
try:
dialect = sniffer.sniff(sample_for_sniff)
except Exception:
# Fallback: choose delimiter by counting common
# separators in header/data line
delims = [",", "\t", ";", "|"]
counts = {
d: (header_line.count(d) if header_line else 0)
+ (data_line.count(d) if data_line else 0)
for d in delims
}
best = (
max(counts, key=counts.get)
if any(counts.values())
else ","
)
class _DefaultDialect(csv.Dialect):
delimiter = best
quotechar = '"'
doublequote = True
skipinitialspace = False
lineterminator = "\n"
quoting = csv.QUOTE_MINIMAL
dialect = _DefaultDialect()
try:
has_header = sniffer.has_header("\n".join(lines[:50]))
except Exception:
has_header = True
header_row = (
list(csv.reader([header_line], dialect))[0]
if header_line
else []
)
first_data_row = (
list(csv.reader([data_line], dialect))[0]
if data_line
else []
)
# Heuristic: if has_header is False but header_row looks
# like names (mostly alphabetic), treat as header
if not has_header:
def _looks_like_header(tokens: List[str]) -> bool:
if not tokens:
return False
# Non-empty, few duplicates, mostly alphabetic
cleaned = [
str(t).strip() for t in tokens if str(t).strip()
]
if not cleaned:
return False
# Allow a few numeric tokens, but most should start with a letter
alpha_starts = sum(
1
for t in cleaned
if t and (t[0].isalpha() or t[0] == "_")
)
return alpha_starts >= max(
1, int(0.6 * len(cleaned))
)
if _looks_like_header(header_row):
has_header = True
if not has_header:
num_cols_guess = len(header_row)
headers = [f"col_{i}" for i in range(num_cols_guess)]
first_data_row = header_row
else:
headers = header_row
num_cols = (
len(first_data_row) if first_data_row else len(headers)
)
# no header
if not headers or all(
(not str(h).strip()) for h in headers
):
headers = [f"col_{i}" for i in range(num_cols or 1)]
headers = self._sanitize_and_dedup_headers(headers)
if num_cols <= 0:
num_cols = len(headers)
headers = headers[:num_cols]
if not headers or any(
h is None or h == "" for h in headers
):
raise csv.Error("Invalid headers after sanitization")
create_sql = f'''
CREATE TABLE IF NOT EXISTS "{table_name}" (
{", ".join([f'"{h}" TEXT' for h in headers])}
)
'''
insert_sql = f'''
INSERT INTO "{table_name}" ({
", ".join([f'"{h}"' for h in headers])
})
VALUES ({
", ".join([":" + f"p{i}" for i in range(len(headers))])
})
'''
with self._connector.session_scope() as session:
logger.debug(
f"Table: {table_name}, headers(final): {headers}"
)
session.execute(text(create_sql))
reader = csv.reader(lines, dialect)
if has_header:
next(reader, None)
batch_params: List[Dict[str, Any]] = []
for row in reader:
if not row:
continue
if len(row) != len(headers):
if len(row) < len(headers):
row += [None] * (len(headers) - len(row))
else:
row = row[: len(headers)]
params = {
f"p{i}": (row[i] if i < len(row) else None)
for i in range(len(headers))
}
batch_params.append(params)
if len(batch_params) >= 1000:
session.execute(text(insert_sql), batch_params)
batch_params = []
if batch_params:
session.execute(text(insert_sql), batch_params)
session.commit()
return True, table_name
except csv.Error:
self._import_with_simple_split_blocking(table_name, content)
return True, table_name
except UnicodeDecodeError:
continue
except Exception as e:
logger.warning(f"Error with encoding {encoding}: {str(e)}")
continue
def _worker():
results = {
"total_files": len(sqlite_files),
"successful": 0,
"failed": 0,
"tables_merged": [],
}
with self._connector.session_scope() as session:
# Get the underlying sqlite3 connection object
connection_proxy = session.connection()
# Handle the different ways SQLAlchemy versions expose the raw connection
try:
with open(file_info["full_path"], "rb") as f:
content = f.read().decode("ascii", errors="ignore")
if content.strip():
self._import_with_simple_split_blocking(table_name, content)
return True, table_name
else:
raise ValueError("File is empty or unreadable")
except Exception as e:
return (
False,
f"Failed to process {file_info['file_name']}: {str(e)}",
)
# SQLAlchemy 1.4+ / 2.0
raw_conn = connection_proxy.connection.dbapi_connection
except AttributeError:
try:
# Older SQLAlchemy versions or certain drivers
raw_conn = connection_proxy.connection
except AttributeError:
# Last resort
raw_conn = session.get_bind().raw_connection()
except Exception as e:
return (
False,
f"Failed to process {file_info.get('full_path', '')}: {str(e)}",
)
# Make sure raw_conn is a sqlite3 connection object
if not raw_conn:
raise RuntimeError("Failed to get raw sqlite3 connection")
for file_info in csv_files:
ok, info = await self._run_in_thread(_process_one_file, file_info)
if ok:
results["successful"] += 1
if info:
results["tables_created"].append(info)
else:
results["failed"] += 1
logger.error(info)
cursor = raw_conn.cursor()
return results
for db_path in sqlite_files:
src_alias = f"src_db_{uuid.uuid4().hex[:8]}"
try:
try:
cursor.execute("PRAGMA database_list")
attached_dbs = cursor.fetchall()
for _, name, _ in attached_dbs:
if name not in ("main", "temp"):
cursor.execute(f"DETACH DATABASE {name}")
except Exception as cleanup_err:
logger.warning(f"Cleanup warning: {cleanup_err}")
cursor.execute(f"ATTACH DATABASE ? AS {src_alias}", (db_path,))
cursor.execute(
f"SELECT name, sql FROM {src_alias}.sqlite_master "
f"WHERE type='table' AND name NOT LIKE 'sqlite_%'"
)
tables = cursor.fetchall()
for table_name, create_sql in tables:
cursor.execute(
"SELECT name FROM sqlite_master "
"WHERE type='table' "
"AND name=?",
(table_name,),
)
if not cursor.fetchone():
cursor.execute(create_sql)
cursor.execute(
f'INSERT INTO main."{table_name}" '
f'SELECT * FROM {src_alias}."{table_name}"'
)
results["tables_merged"].append(table_name)
else:
logger.warning(
f"Table '{table_name}' exists. Skipping."
)
raw_conn.commit()
results["successful"] += 1
except Exception as e:
logger.error(f"Failed to merge {db_path}: {e}")
results["failed"] += 1
try:
raw_conn.rollback()
except Exception:
pass
finally:
try:
cursor.execute(f"DETACH DATABASE {src_alias}")
except Exception:
pass
return results
return await self._run_in_thread(_worker)
def _import_with_simple_split_blocking(self, table_name: str, content: str):
"""Fallback method for malformed CSV files (blocking, 使用 SQLAlchemy 执行)"""