Mirror of https://github.com/csunny/DB-GPT.git (synced 2026-01-13 19:55:44 +00:00)

feat(benchmark): update the benchmark task to use the latest Falcon GitHub repo questions
@@ -10,6 +10,8 @@ class StorageUtil:

    YUQUE_URL_PREFIX = "https://yuque.com"

+    GITHUB_FALCON_PREFIX = "https://github.com/eosphoros-ai/Falcon"
+
    @staticmethod
    def get_file_parse_type(file_path: Optional[str]) -> FileParseTypeEnum:
        """Get file parsing type based on file path
@@ -28,5 +30,7 @@ class StorageUtil:
        if file_path.strip().startswith(StorageUtil.YUQUE_URL_PREFIX):
            return FileParseTypeEnum.YU_QUE
+        if file_path.strip().startswith(StorageUtil.GITHUB_FALCON_PREFIX):
+            return FileParseTypeEnum.GITHUB

        return FileParseTypeEnum.EXCEL
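In effect, the new prefix routes the Falcon repository URL to the GitHub parser while all other inputs keep their old behavior. A quick sketch of the routing, assuming the enum values shown in this diff:

    assert StorageUtil.get_file_parse_type(
        "https://github.com/eosphoros-ai/Falcon"
    ) == FileParseTypeEnum.GITHUB
    assert StorageUtil.get_file_parse_type("https://yuque.com/doc") == FileParseTypeEnum.YU_QUE
    assert StorageUtil.get_file_parse_type("/tmp/benchmark.xlsx") == FileParseTypeEnum.EXCEL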
@@ -1,11 +1,5 @@
from .benchmark_service import BenchmarkService
from .data_compare_service import DataCompareService
from .file_parse_service import FileParseService
from .user_input_execute_service import UserInputExecuteService

__all__ = [
    "BenchmarkService",
    "FileParseService",
    "UserInputExecuteService",
    "DataCompareService",
]
@@ -18,6 +18,7 @@ from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.storage.metadata import BaseDao
from dbgpt.util import PaginationResult, get_or_create_event_loop
+from dbgpt.util.benchmarks import StorageUtil
from dbgpt_serve.evaluate.service.benchmark.task.benchmark_agent_task import (
    BenchmarkAgentTask,
)
@@ -40,7 +41,6 @@ from ...config import ServeConfig
from ...models.models import ServeDao, ServeEntity
from ..fetchdata.benchmark_data_manager import get_benchmark_manager
from .data_compare_service import DataCompareService
from .ext.excel_file_parse import ExcelFileParseService
from .models import (
    BaseInputModel,
    BenchmarkDataSets,
@@ -49,7 +49,6 @@ from .models import (
    BenchmarkModeTypeEnum,
    BenchmarkTaskResult,
    ContentTypeEnum,
    FileParseTypeEnum,
    InputType,
    OutputType,
)
@@ -61,10 +60,7 @@ executor = ThreadPoolExecutor(max_workers=5)

BENCHMARK_SERVICE_COMPONENT_NAME = "dbgpt_serve_evaluate_benchmark_service"

-STANDARD_BENCHMARK_FILE_PATH = os.path.join(
-    BENCHMARK_DATA_ROOT_PATH,
-    "2025_07_27_public_500_standard_benchmark_question_list.xlsx",
-)
+STANDARD_BENCHMARK_FILE_PATH = "https://github.com/eosphoros-ai/Falcon"

BENCHMARK_OUTPUT_RESULT_PATH = os.path.join(BENCHMARK_DATA_ROOT_PATH, "result")
@@ -94,11 +90,10 @@ class BenchmarkService(
        super().__init__(system_app)
        self.rag_service = get_rag_service(system_app)
        self.prompt_service = get_prompt_service(system_app)
-        self._file_parse_type = FileParseTypeEnum.EXCEL
+        self._file_parse_type = StorageUtil.get_file_parse_type(STANDARD_BENCHMARK_FILE_PATH)

-        fps = ExcelFileParseService()
        dcs = DataCompareService()
-        self.user_input_execute_service = UserInputExecuteService(fps, dcs)
+        self.user_input_execute_service = UserInputExecuteService(dcs, self._file_parse_type)

        self.trigger_executor = ThreadPoolExecutor(
            max_workers=5, thread_name_prefix="benchmark-fileWrite"
@@ -289,7 +284,7 @@ class BenchmarkService(
            await manager.load_data()
            logger.info(
                f"Benchmark dataset loaded from {manager._config.repo_url} "
-                f"dir={manager._config.data_dir}"
+                f"dir={manager._config.data_dirs}"
            )
        except Exception as e:
            logger.error(
@@ -0,0 +1,176 @@
import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional

import pandas as pd
from openpyxl import Workbook, load_workbook

from dbgpt.util.benchmarks.ExcelUtils import ExcelUtils

from ..file_parse_service import FileParseService
from ..models import (
    AnswerExecuteModel,
    BaseInputModel,
    BenchmarkDataSets,
    DataCompareStrategyConfig,
)

logger = logging.getLogger(__name__)


class ExcelFileParseService(FileParseService):
    def parse_input_sets(self, path: str) -> BenchmarkDataSets:
        """
        Parse input sets from excel file
        Args:
            path: File location path
        Returns:
            BenchmarkDataSets: Parsed data sets
        """
        input_stream = self.get_input_stream(path)

        if input_stream is None:
            raise RuntimeError(f"file not found! path: {path}")

        # Parse excel file to get data sets
        input_sets = BenchmarkDataSets()
        workbook = None

        try:
            workbook = load_workbook(input_stream, data_only=True)
            input_list = []

            # Get the first worksheet
            sheet = workbook.worksheets[0]

            for row_num in range(
                2, sheet.max_row + 1
            ):  # Skip header row (start from 1-based index)
                row = sheet[row_num]
                if ExcelUtils.is_row_empty(row):
                    continue

                # Get content from columns 1-6 (0-based index 0-5) and the
                # prompt column (0-based index 8)
                serial_no_cell = row[0]
                analysis_model_id_cell = row[1]
                question_cell = row[2]
                self_define_tags_cell = row[3]
                knowledge_cell = row[4]
                llm_output_cell = row[5]
                prompt_cell = row[8]

                # Build input model
                input_model = BaseInputModel(
                    serial_no=int(
                        ExcelUtils.get_cell_value_as_string(serial_no_cell) or "0"
                    ),
                    analysis_model_id=ExcelUtils.get_cell_value_as_string(
                        analysis_model_id_cell
                    ),
                    question=ExcelUtils.get_cell_value_as_string(question_cell),
                    self_define_tags=ExcelUtils.get_cell_value_as_string(
                        self_define_tags_cell
                    ),
                    llm_output=ExcelUtils.get_cell_value_as_string(llm_output_cell),
                    knowledge=ExcelUtils.get_cell_value_as_string(knowledge_cell),
                    prompt=ExcelUtils.get_cell_value_as_string(prompt_cell),
                )

                input_list.append(input_model)

            input_sets.data_list = input_list
        except Exception as e:
            logger.error(f"parse excel error, path: {path}, errorMsg: {e}")
        finally:
            try:
                if workbook is not None:
                    workbook.close()
            except Exception as e:
                logger.error(f"close workbook error, path: {path}, errorMsg: {e}")

        return input_sets
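The sheet layout this parser assumes, inferred from the cell indexes above (columns at 0-based indexes 6-7 are skipped):

    row[0] serial_no | row[1] analysis_model_id | row[2] question |
    row[3] self_define_tags | row[4] knowledge | row[5] llm_output | row[8] prompt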
    def parse_standard_benchmark_sets(
        self, standard_excel_path: str
    ) -> List[AnswerExecuteModel]:
        df = pd.read_excel(standard_excel_path, sheet_name=0)
        outputs: List[AnswerExecuteModel] = []
        for _, row in df.iterrows():
            try:
                serial_no = int(row["编号"])
            except Exception:
                continue
            question = row.get("用户问题")
            analysis_model_id = row.get("数据集ID")
            llm_output = (
                None if pd.isna(row.get("标准答案SQL")) else str(row.get("标准答案SQL"))
            )
            order_by = True
            if not pd.isna(row.get("是否排序")):
                try:
                    order_by = bool(int(row.get("是否排序")))
                except Exception:
                    order_by = True

            std_result: Optional[List[Dict[str, List[str]]]] = None
            if not pd.isna(row.get("标准结果")):
                std_result_raw = str(row.get("标准结果"))
                std_result = self._parse_multi_standard_result(std_result_raw)

            strategy_config = DataCompareStrategyConfig(
                strategy="CONTAIN_MATCH",
                order_by=order_by,
                standard_result=std_result if std_result is not None else None,
            )
            outputs.append(
                AnswerExecuteModel(
                    serialNo=serial_no,
                    analysisModelId=analysis_model_id,
                    question=question,
                    llmOutput=llm_output,
                    executeResult=std_result,
                    strategyConfig=strategy_config,
                )
            )
        return outputs

    def _parse_multi_standard_result(
        self, std_result_raw: str
    ) -> Optional[List[Dict[str, List[str]]]]:
        """
        Parse multiple standard results from raw string data.

        Handles multiple results separated by newlines and parses each line as a dict.

        Args:
            std_result_raw (str): Raw standard result string with multiple lines

        Returns:
            Optional[List[Dict[str, List[str]]]]: List of parsed dictionaries,
                or None if parsing fails or no valid data
        """
        try:
            std_result_raw = std_result_raw.strip()
            if not std_result_raw:
                return None

            # Multiple results are separated by newlines
            result_lines = std_result_raw.split("\n")
            result_list = []

            for line in result_lines:
                line = line.strip()
                if line:
                    try:
                        result_list.append(json.loads(line))
                    except Exception as e:
                        logger.warning(
                            f"Failed to parse line as JSON: {line}, error: {e}"
                        )
                        continue

            return result_list if result_list else None
        except Exception as e:
            logger.error(f"parse multiple standard results error: {e}")
            return None
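To make the expected cell format concrete: each line of the 标准结果 column is an independent JSON object mapping a column name to its list of values. A minimal sketch with invented values, assuming the service constructs cleanly in isolation:

    raw = '{"item_type": ["办公用品", "设备"]}\n{"order_count": ["3", "5"]}'
    svc = ExcelFileParseService()  # assumption: no DB needed at construction time
    assert svc._parse_multi_standard_result(raw) == [
        {"item_type": ["办公用品", "设备"]},
        {"order_count": ["3", "5"]},
    ]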
@@ -0,0 +1,572 @@
import asyncio
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from dbgpt.util import get_or_create_event_loop
from dbgpt_serve.evaluate.service.benchmark.file_parse_service import FileParseService
from dbgpt_serve.evaluate.service.benchmark.models import (
    AnswerExecuteModel,
    BaseInputModel,
    BenchmarkDataSets,
    DataCompareStrategyConfig,
)
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
    FileLoadResult,
    GoldenSqlListResult,
    get_benchmark_manager,
)

logger = logging.getLogger(__name__)

TEXT_SQL_PROMPT = """
Given the following dataset, including field names and sampling information:
{Schema}

Containing the following knowledge:
{Knowledge}

Based on the above information, please generate the SQL for the following question:
{Query}

[Output Requirements]
* Field names in the generated SQL must use the actual field names from the table schema.
* Table names in the generated SQL must use the actual table names provided in the schema.
* Physical field names in the generated SQL must originate from the corresponding physical tables; generating fields that do not exist in the table is not allowed.
* Cartesian product calculations are not allowed in the generated SQL. This includes `CROSS JOIN`, `JOIN` operations missing `ON` or `USING` conditions, and multi-table joins without conditions set for all relationships between tables.
* The generated SQL must strictly adhere to {dialect} syntax. If it does not comply with this syntax, please regenerate it.
* Output only pure, executable SQL without any additional information.

[Example]
** Table 1 Information
*** Table Name: orders
*** DDL Statement:
CREATE TABLE "orders" (
    "order_id" TEXT,
    "customer_id" TEXT,
    "item_type" TEXT,
    "order_date" DATE
)
*** Field Information:
|Field Name|Field Type|Sample Data|
|:--:|:--:|:--:|
|order_id|text|CN-2025-2562,CN-2025-8623,CN-2025-6535|
|customer_id|text|52351,56263,71252|
|item_type|text|办公用品,设备,家具|
|order_date|date|2023-02-21,2024-03-30,2024-12-20|

User Question: 请帮我统计下各商品类型的订单数
Final Output SQL:
SELECT
    item_type,
    COUNT(*) as order_count
FROM
    orders
GROUP BY
    item_type;
"""
@dataclass
class BenchmarkDataItem:
    """benchmark data info item"""

    question_id: int
    db_id: str
    question: str
    sql: str
    answer: List[Dict[str, List[str]]]
    is_order: str

    @staticmethod
    def from_dict(data: dict) -> "BenchmarkDataItem":
        return BenchmarkDataItem(
            question_id=data.get("question_id", 0),
            db_id=str(data.get("db_id", "")),
            question=str(data.get("question", "")),
            sql=str(data.get("SQL", "")),
            answer=data.get("answer", []),
            is_order=data.get("is_order", "0"),
        )


@dataclass
class ColumnItem:
    """column info item"""

    column_id: int
    column_name: str
    column_type: str
    sample_values: list

    @staticmethod
    def from_dict(data: dict) -> "ColumnItem":
        """Create a ColumnItem instance from a dict

        Args:
            data: dict containing the column info

        Returns:
            ColumnItem: column info instance
        """
        return ColumnItem(
            column_id=data.get("column_id", 0),
            column_name=data.get("column_name", ""),
            column_type=data.get("column_type", ""),
            sample_values=data.get("sample_values", []),
        )


@dataclass
class TableDDLItem:
    """Table DDL Info Item"""

    table_id: int
    table_name: str
    columns: List[ColumnItem]
    ddl: Optional[str] = None

    @staticmethod
    def from_dict(data: dict) -> "TableDDLItem":
        """Create a TableDDLItem instance from a dict

        Args:
            data: dict containing the table info

        Returns:
            TableDDLItem: table info instance
        """
        columns_data = data.get("columns", [])
        columns = [ColumnItem.from_dict(col) for col in columns_data]

        return TableDDLItem(
            table_id=data.get("table_id", 0),
            table_name=data.get("table_name", ""),
            columns=columns,
            ddl=data.get("ddl"),
        )


@dataclass
class TableDataItem:
    """Table Data Info Item"""

    db_id: str
    table_ddl: List[TableDDLItem]

    @staticmethod
    def from_dict(data: dict) -> "TableDataItem":
        """Create a TableDataItem instance from a dict

        Args:
            data: dict containing the database table info

        Returns:
            TableDataItem: database table info instance
        """
        tables_data = data.get("tables", [])
        table_ddl = [TableDDLItem.from_dict(table) for table in tables_data]

        return TableDataItem(
            db_id=str(data.get("db_id", "")),
            table_ddl=table_ddl,
        )
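`from_dict` pins down the row schema expected from the Falcon dev.json file; a hedged sample row, with field values invented for illustration:

    item = BenchmarkDataItem.from_dict({
        "question_id": 1,
        "db_id": "california_schools",   # hypothetical db_id
        "question": "How many schools are there?",
        "SQL": "SELECT COUNT(*) FROM schools;",
        "answer": [{"count": ["123"]}],
        "is_order": "0",
    })
    assert item.sql.startswith("SELECT") and item.answer[0]["count"] == ["123"]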
class SafeDict(dict):
    def __missing__(self, key):
        return '{' + key + '}'
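`SafeDict` makes `str.format_map` non-strict: any placeholder missing from the mapping is re-emitted literally instead of raising `KeyError`. A quick illustration:

    template = "SELECT {cols} FROM {table}"
    partial = template.format_map(SafeDict(cols="*"))
    assert partial == "SELECT * FROM {table}"  # {table} survives for a later pass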
class FalconFileParseService(FileParseService):
    def __init__(self):
        super().__init__()
        self._dev_data_file = "dev_data/dev.json"
        self._dev_table_ddl_file = "dev_data/tables.json"

        self.benchmark_manager = get_benchmark_manager()

        self._dev_data: Optional[FileLoadResult] = None
        self._dev_table_ddl: Optional[FileLoadResult] = None
        self._data_loaded = False
    @staticmethod
    def _format_answer_list(answer_list: List[Dict[str, List[str]]]) -> str:
        """Format the answer list as a string

        Args:
            answer_list: answer list; each element is a dict whose values are
                lists of strings

        Returns:
            str: JSON-formatted string, or an empty string if the list is empty
        """
        if not answer_list:
            return ""

        try:
            import json

            # Convert the answer list to JSON strings, one answer per line
            return "\n".join(
                json.dumps(item, ensure_ascii=False) for item in answer_list
            )
        except Exception as e:
            logger.warning(f"Failed to format answer list: {e}")
            return str(answer_list)
    def _ensure_data_loaded(self):
        if self._data_loaded:
            return
        logger.info("Loading benchmark data for the first time...")
        try:
            self._dev_data, self._dev_table_ddl = self._load_data_sync()
            self._data_loaded = True
            logger.info("Benchmark data loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load benchmark data: {e}", exc_info=True)
            raise RuntimeError(f"Failed to load benchmark data: {e}")
    def parse_input_sets(self, path: str) -> BenchmarkDataSets:
        """
        Parse input sets from github repo
        Args:
            path: File URL path
        Returns:
            BenchmarkDataSets: Parsed data sets
        """
        self._ensure_data_loaded()

        try:
            # 1. Parse the benchmark question data
            benchmark_data_list = self._parse_benchmark_data(self._dev_data)
            if not benchmark_data_list:
                logger.error("Failed to parse benchmark data")
                return BenchmarkDataSets(data_list=[])

            # 2. Parse the table schemas
            table_ddl_data_list = self._parse_table_ddl_data(self._dev_table_ddl)
            if not table_ddl_data_list:
                logger.error("Failed to parse table ddl data")
                return BenchmarkDataSets(data_list=[])
            table_ddl_data_map = {x.db_id: x.table_ddl for x in table_ddl_data_list}

            # 3. Convert question items into a BaseInputModel list and attach
            # the standard answers
            input_models = []
            for idx, question_item in enumerate(benchmark_data_list, start=1):
                input_model = BaseInputModel(
                    serial_no=question_item.question_id,
                    analysis_model_id=question_item.db_id,
                    question=question_item.question,
                    self_define_tags="",
                    knowledge="",
                    llm_output=self._format_answer_list(question_item.answer),
                    prompt=self.load_benchmark_prompt_template(
                        question_item, table_ddl_data_map.get(question_item.db_id)
                    ),
                )
                input_models.append(input_model)
            logger.info(f"Successfully parsed {len(input_models)} question items")
            return BenchmarkDataSets(data_list=input_models)
        except Exception as e:
            logger.error(
                f"load remote benchmark data error, error: {str(e)}",
                exc_info=True,
            )
            return BenchmarkDataSets(data_list=[])
    def parse_standard_benchmark_sets(
        self, standard_excel_path: str
    ) -> List[AnswerExecuteModel]:
        self._ensure_data_loaded()

        outputs: List[AnswerExecuteModel] = []
        # 1. Parse the benchmark question data
        benchmark_data_list = self._parse_benchmark_data(self._dev_data)
        if not benchmark_data_list:
            logger.error("Failed to parse benchmark data")
            return outputs

        for idx, question_item in enumerate(benchmark_data_list, start=1):
            serial_no = question_item.question_id
            question = question_item.question
            analysis_model_id = question_item.db_id
            llm_output = question_item.sql
            order_by = True
            if question_item.is_order:
                try:
                    order_by = bool(int(question_item.is_order))
                except Exception:
                    order_by = True

            std_result: Optional[List[Dict[str, List[str]]]] = None
            if question_item.answer:
                std_result = self._parse_multi_standard_result(question_item.answer)

            strategy_config = DataCompareStrategyConfig(
                strategy="CONTAIN_MATCH",
                order_by=order_by,
                standard_result=std_result if std_result is not None else None,
            )
            outputs.append(
                AnswerExecuteModel(
                    serialNo=serial_no,
                    analysisModelId=analysis_model_id,
                    question=question,
                    llmOutput=llm_output,
                    executeResult=std_result,
                    strategyConfig=strategy_config,
                )
            )
        return outputs
    def _parse_benchmark_data(
        self, benchmark_data: Optional[FileLoadResult]
    ) -> Optional[List[BenchmarkDataItem]]:
        """
        Parse the question data
        Args:
            benchmark_data: question file data loaded from GitHub
        Returns:
            List[BenchmarkDataItem]: question list, or None if parsing fails
        """
        if not benchmark_data or not benchmark_data.rows:
            return None

        if benchmark_data.failed_count > 0:
            logger.warning(
                f"Question data has {benchmark_data.failed_count} failed rows"
            )

        benchmark_data_list = []
        for row in benchmark_data.rows:
            if not isinstance(row.data, dict):
                logger.warning(
                    f"Row {row.line_no} data is not a dict: {type(row.data)}"
                )
                continue

            benchmark_data_item = BenchmarkDataItem.from_dict(row.data)

            if (
                not benchmark_data_item.question_id
                or not benchmark_data_item.question
                or not benchmark_data_item.db_id
            ):
                logger.warning(
                    f"Row {row.line_no} missing required fields: "
                    f"question_id={benchmark_data_item.question_id}, "
                    f"question={benchmark_data_item.question}, "
                    f"db_id={benchmark_data_item.db_id}"
                )
                continue

            benchmark_data_list.append(benchmark_data_item)

        if not benchmark_data_list:
            logger.error("No valid benchmark data parsed")
            return None

        logger.info(
            f"Successfully parsed {len(benchmark_data_list)} benchmark data"
            f" from {len(benchmark_data.rows)} rows"
        )
        # TODO: temporarily take only the first 100 items
        return benchmark_data_list[:100]
    def _parse_table_ddl_data(
        self, table_ddl_data: Optional[FileLoadResult]
    ) -> Optional[List[TableDataItem]]:
        """
        Parse the table DDL data
        Args:
            table_ddl_data: table DDL data loaded from GitHub
        Returns:
            List[TableDataItem]: table DDL list, or None if parsing fails
        """
        if not table_ddl_data or not table_ddl_data.rows:
            logger.warning("table ddl data is None")
            return None
        if table_ddl_data.failed_count > 0:
            logger.warning(
                f"table ddl data has {table_ddl_data.failed_count} failed items"
            )

        table_ddl_data_list = []
        for row in table_ddl_data.rows:
            if not isinstance(row.data, dict):
                logger.warning(
                    f"Row {row.line_no} data is not a dict: {type(row.data)}"
                )
                continue

            table_ddl_data_item = TableDataItem.from_dict(row.data)

            if (
                not table_ddl_data_item.db_id
                or not table_ddl_data_item.table_ddl
            ):
                logger.warning(
                    f"Row {row.line_no} missing required fields: "
                    f"db_id={table_ddl_data_item.db_id}, "
                    f"table_ddl={table_ddl_data_item.table_ddl}"
                )
                continue

            table_ddl_data_list.append(table_ddl_data_item)

        if not table_ddl_data_list:
            return None

        logger.info(
            f"Successfully parsed {len(table_ddl_data_list)} table DDL data"
            f" from {len(table_ddl_data.rows)} rows"
        )
        return table_ddl_data_list
    async def _async_load_data(
        self,
    ) -> Tuple[
        Optional[FileLoadResult],
        Optional[FileLoadResult],
    ]:
        """Load the two data files concurrently

        Uses asyncio.gather to run both async tasks concurrently for faster loading

        Returns:
            Tuple: (dev_data, dev_table_ddl)
        """
        dev_data, dev_table_ddl = await asyncio.gather(
            self.benchmark_manager.load_file_from_github(self._dev_data_file),
            self.benchmark_manager.load_file_from_github(self._dev_table_ddl_file),
        )
        return dev_data, dev_table_ddl
    def _load_data_sync(
        self,
    ) -> Tuple[
        Optional[FileLoadResult],
        Optional[FileLoadResult],
    ]:
        """Load the data from a synchronous context

        Detects the current event-loop state:
        - if an event loop is already running, run the async code in a new
          thread via a thread pool
        - if no loop is running, use run_until_complete directly

        Returns:
            Tuple: (dev_data, dev_table_ddl)
        """
        try:
            # Try to get the currently running event loop
            asyncio.get_running_loop()
            # Reaching this point means a loop is running: use a thread pool
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor() as executor:
                future = executor.submit(asyncio.run, self._async_load_data())
                return future.result()
        except RuntimeError:
            # No running loop: execute directly
            loop = get_or_create_event_loop()
            return loop.run_until_complete(self._async_load_data())
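The thread-pool detour exists because `run_until_complete` cannot be driven from inside an already-running loop. A minimal standalone sketch of the same two paths (names are illustrative, not from this module):

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    async def fetch() -> str:
        return "data"

    def load_sync() -> str:
        try:
            asyncio.get_running_loop()      # raises RuntimeError if no loop
        except RuntimeError:
            return asyncio.run(fetch())     # safe: no loop in this thread
        # A loop is running (called from async code): run in a fresh thread,
        # which gets its own event loop via asyncio.run.
        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, fetch()).result()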
    def _run_in_new_loop(
        self,
    ) -> Tuple[
        Optional[FileLoadResult],
        Optional[FileLoadResult],
    ]:
        """
        Run the async task in a new event loop

        Returns:
            Tuple: (dev_data, dev_table_ddl)
        """
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        try:
            return new_loop.run_until_complete(self._async_load_data())
        finally:
            new_loop.close()
    @staticmethod
    def _build_table_schema_info(table_ddl_list: Optional[List[TableDDLItem]]) -> str:
        """Build the schema info for the tables

        Args:
            table_ddl_list: list of table DDL info

        Returns:
            str: formatted table schema info string
        """
        if not table_ddl_list:
            return ""

        schema_parts = []

        for idx, table in enumerate(table_ddl_list, start=1):
            # Table header info
            schema_parts.append(f"** Table {idx} Information")
            schema_parts.append(f"*** Table Name: {table.table_name}")

            # If a DDL is available, append it
            if table.ddl:
                schema_parts.append("*** DDL Statement:")
                # The DDL may span multiple lines; indent them
                ddl_lines = table.ddl.strip().split('\n')
                for ddl_line in ddl_lines:
                    schema_parts.append(f"  {ddl_line}")

            # Column table header - use a clearer format
            schema_parts.append("*** Field Information:")
            schema_parts.append("|Field Name|Field Type|Sample Data|")
            schema_parts.append("|:--:|:--:|:--:|")

            # Append one row per column
            for column in table.columns:
                # Format sample values - keep them within a reasonable length
                if column.sample_values:
                    # Take the first few sample values, comma-joined, to avoid
                    # overly long cells
                    sample_count = min(3, len(column.sample_values))
                    sample_str = ",".join(
                        str(val) for val in column.sample_values[:sample_count]
                    )
                    # Truncate overly long sample strings
                    if len(sample_str) > 100:
                        sample_str = sample_str[:97] + "..."
                else:
                    sample_str = "-"

                schema_parts.append(
                    f"|{column.column_name}|{column.column_type}|{sample_str}|"
                )

            # Blank line between tables
            if idx < len(table_ddl_list):
                schema_parts.append("")

        return "\n".join(schema_parts)
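For a one-table input this produces exactly the layout used in the prompt's [Example] section, for instance (sample trimmed):

    ** Table 1 Information
    *** Table Name: orders
    *** Field Information:
    |Field Name|Field Type|Sample Data|
    |:--:|:--:|:--:|
    |order_id|text|CN-2025-2562,CN-2025-8623,CN-2025-6535|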
    def _parse_multi_standard_result(
        self, answer_list: List[Dict[str, List[str]]]
    ) -> Optional[List[Dict[str, List[str]]]]:
        """
        Parse the standard answer results

        Args:
            answer_list: answer list, already in the correct format

        Returns:
            Optional[List[Dict[str, List[str]]]]: the answer list, or None if empty
        """
        try:
            if not answer_list:
                return None
            return answer_list if answer_list else None
        except Exception as e:
            logger.error(f"parse standard results error: {e}")
            return None
    def load_benchmark_prompt_template(
        self, question_item: BenchmarkDataItem, table_ddl: List[TableDDLItem]
    ) -> str:
        """
        build benchmark prompt template
        """
        schema = self._build_table_schema_info(table_ddl)
        format_params = {
            "Schema": schema,
            "Knowledge": "",
            "Query": question_item.question,
        }
        # Use SafeDict with format_map for non-strict formatting: missing
        # variables (e.g. {dialect}) do not raise
        return TEXT_SQL_PROMPT.format_map(SafeDict(format_params))
@@ -4,20 +4,16 @@ import logging
import os
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import List, Optional, Any, Dict

import pandas as pd
from openpyxl import Workbook, load_workbook

from dbgpt.util.benchmarks.ExcelUtils import ExcelUtils
from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao

from .models import (
    AnswerExecuteModel,
    BaseInputModel,
    BenchmarkDataSets,
    BenchmarkExecuteConfig,
    DataCompareStrategyConfig,
    OutputType,
    RoundAnswerConfirmModel,
)
@@ -27,8 +23,6 @@ logger = logging.getLogger(__name__)

class FileParseService(ABC):
    def __init__(self):
        self._benchmark_dao = BenchmarkResultDao()

        # export column configuration file path
        self._column_config_file_path = os.path.join(
            os.path.dirname(__file__),
@@ -56,7 +50,6 @@ class FileParseService(ABC):
            data.append(AnswerExecuteModel.from_dict(obj))
        return data

-    @abstractmethod
    def write_data_compare_result(
        self,
        path: str,
@@ -65,15 +58,100 @@ class FileParseService(ABC):
        is_execute: bool,
        llm_count: int,
    ):
-        """Write compare results to File
+        """Write compare results to an Excel file

        Args:
            path: Output file path
            round_id: Round ID
            confirm_models: List of answer confirm models
            is_execute: Whether to execute the comparison
            llm_count: LLM count
        The output Excel file will be named as '<base>.xlsx' and
        sheet name is 'benchmark_compare_result'. If the file exists, it will
        append rows; otherwise it will create a new file with headers.
        """
        try:
            # Ensure output directory exists
            output_dir = Path(path).parent
            output_dir.mkdir(parents=True, exist_ok=True)

            output_file = path

            headers = [
                "serialNo",
                "analysisModelId",
                "question",
                "selfDefineTags",
                "prompt",
                "standardAnswerSql",
                "standardAnswer",
                "llmCode",
                "llmOutput",
                "executeResult",
                "errorMsg",
                "compareResult",
            ]

            # Load or create workbook and sheet
            if Path(output_file).exists():
                workbook = load_workbook(str(output_file))
                if "benchmark_compare_result" in workbook.sheetnames:
                    worksheet = workbook["benchmark_compare_result"]
                else:
                    worksheet = workbook.create_sheet("benchmark_compare_result")
                    # Write headers if new sheet
                    for col_idx, header in enumerate(headers, 1):
                        worksheet.cell(row=1, column=col_idx, value=header)
            else:
                workbook = Workbook()
                worksheet = workbook.active
                worksheet.title = "benchmark_compare_result"
                # Write headers
                for col_idx, header in enumerate(headers, 1):
                    worksheet.cell(row=1, column=col_idx, value=header)

            # Determine start row to append
            start_row = worksheet.max_row + 1 if worksheet.max_row else 2

            # Append rows
            for idx, cm in enumerate(confirm_models):
                row_data = [
                    cm.serialNo,
                    cm.analysisModelId,
                    cm.question,
                    cm.selfDefineTags,
                    cm.prompt,
                    cm.standardAnswerSql,
                    self._format_set_result(cm.strategyConfig.standard_result)
                    if cm.strategyConfig is not None
                    else "",
                    cm.llmCode,
                    cm.llmOutput,
                    json.dumps(cm.executeResult, ensure_ascii=False)
                    if cm.executeResult is not None
                    else "",
                    cm.errorMsg,
                    cm.compareResult.value if cm.compareResult else None,
                ]
                for col_idx, value in enumerate(row_data, 1):
                    worksheet.cell(row=start_row + idx, column=col_idx, value=value)

            # Autosize columns (simple strategy)
            for column in worksheet.columns:
                max_length = 0
                column_letter = column[0].column_letter
                for cell in column:
                    try:
                        if cell.value and len(str(cell.value)) > max_length:
                            max_length = len(str(cell.value))
                    except Exception:
                        pass
                adjusted_width = min(max(max_length + 2, 10), 80)
                worksheet.column_dimensions[column_letter].width = adjusted_width

            workbook.save(str(output_file))
            workbook.close()
            logger.info(
                f"[write_data_compare_result] compare written to Excel: {output_file}"
            )
        except Exception as e:
            logger.error(
                f"[write_data_compare_result] write excel error for path={path}: {e}"
            )

    def summary_and_write_multi_round_benchmark_result(
        self, output_path: str, round_id: int
@@ -163,7 +241,6 @@ class FileParseService(ABC):
        """
        pass

-    @abstractmethod
    def write_multi_round_benchmark_result(
        self,
        output_file_path: str,
@@ -186,149 +263,6 @@ class FileParseService(ABC):
            start_index: Starting index (batch start row index)
            offset: Offset (file rows offset)
        """


class ExcelFileParseService(FileParseService):
    def parse_input_sets(self, path: str) -> BenchmarkDataSets:
        """
        Parse input sets from excel file
        Args:
            path: File location path
        Returns:
            BenchmarkDataSets: Parsed data sets
        """
        input_stream = self.get_input_stream(path)

        if input_stream is None:
            raise RuntimeError(f"file not found! path: {path}")

        # Parse excel file to get data sets
        input_sets = BenchmarkDataSets()
        workbook = None

        try:
            workbook = load_workbook(input_stream, data_only=True)
            input_list = []

            # Get the first worksheet
            sheet = workbook.worksheets[0]

            for row_num in range(
                2, sheet.max_row + 1
            ):  # Skip header row (start from 1-based index)
                row = sheet[row_num]
                if ExcelUtils.is_row_empty(row):
                    continue

                # Get content from columns 1-6 (0-based index 0-5) and the
                # prompt column (0-based index 8)
                serial_no_cell = row[0]
                analysis_model_id_cell = row[1]
                question_cell = row[2]
                self_define_tags_cell = row[3]
                knowledge_cell = row[4]
                llm_output_cell = row[5]
                prompt_cell = row[8]

                # Build input model
                input_model = BaseInputModel(
                    serial_no=int(
                        ExcelUtils.get_cell_value_as_string(serial_no_cell) or "0"
                    ),
                    analysis_model_id=ExcelUtils.get_cell_value_as_string(
                        analysis_model_id_cell
                    ),
                    question=ExcelUtils.get_cell_value_as_string(question_cell),
                    self_define_tags=ExcelUtils.get_cell_value_as_string(
                        self_define_tags_cell
                    ),
                    llm_output=ExcelUtils.get_cell_value_as_string(llm_output_cell),
                    knowledge=ExcelUtils.get_cell_value_as_string(knowledge_cell),
                    prompt=ExcelUtils.get_cell_value_as_string(prompt_cell),
                )

                input_list.append(input_model)

            input_sets.data_list = input_list
        except Exception as e:
            logger.error(f"parse excel error, path: {path}, errorMsg: {e}")
        finally:
            try:
                if workbook is not None:
                    workbook.close()
            except Exception as e:
                logger.error(f"close workbook error, path: {path}, errorMsg: {e}")

        return input_sets

    def parse_standard_benchmark_sets(
        self, standard_excel_path: str
    ) -> List[AnswerExecuteModel]:
        df = pd.read_excel(standard_excel_path, sheet_name=0)
        outputs: List[AnswerExecuteModel] = []
        for _, row in df.iterrows():
            try:
                serial_no = int(row["编号"])
            except Exception:
                continue
            question = row.get("用户问题")
            analysis_model_id = row.get("数据集ID")
            llm_output = (
                None if pd.isna(row.get("标准答案SQL")) else str(row.get("标准答案SQL"))
            )
            order_by = True
            if not pd.isna(row.get("是否排序")):
                try:
                    order_by = bool(int(row.get("是否排序")))
                except Exception:
                    order_by = True

            std_result: Optional[List[Dict[str, List[str]]]] = None
            if not pd.isna(row.get("标准结果")):
                std_result_raw = str(row.get("标准结果"))
                std_result = self._parse_multi_standard_result(std_result_raw)

            strategy_config = DataCompareStrategyConfig(
                strategy="CONTAIN_MATCH",
                order_by=order_by,
                standard_result=std_result if std_result is not None else None,
            )
            outputs.append(
                AnswerExecuteModel(
                    serialNo=serial_no,
                    analysisModelId=analysis_model_id,
                    question=question,
                    llmOutput=llm_output,
                    executeResult=std_result,
                    strategyConfig=strategy_config,
                )
            )
        return outputs

    def write_multi_round_benchmark_result(
        self,
        output_file_path: str,
        round_id: int,
        config: BenchmarkExecuteConfig,
        inputs: List[BaseInputModel],
        outputs: List[OutputType],
        start_index: int,
        offset: int,
    ) -> bool:
        """
        Write the benchmark Result to Excel File With Multi Round

        Args:
            output_file_path: Output file path
            round_id: Round ID
            config: Benchmark configuration
            inputs: List of input data
            outputs: List of output data
            start_index: Starting index (batch start row index)
            offset: Offset (file rows offset)

        Returns:
            bool: Returns True if write is successful, False otherwise
        """
        try:
            # Ensure the output directory exists
            output_dir = Path(output_file_path).parent
@@ -436,109 +370,6 @@ class ExcelFileParseService(FileParseService):
            logger.error(f"write excel file error: {e}", exc_info=True)
            return False

    def write_data_compare_result(
        self,
        path: str,
        round_id: int,
        confirm_models: List[RoundAnswerConfirmModel],
        is_execute: bool,
        llm_count: int,
    ):
        """Write compare results to an Excel file

        The output Excel file will be named as '<base>.xlsx' and
        sheet name is 'benchmark_compare_result'. If the file exists, it will
        append rows; otherwise it will create a new file with headers.
        """
        try:
            # Ensure output directory exists
            output_dir = Path(path).parent
            output_dir.mkdir(parents=True, exist_ok=True)

            output_file = path

            headers = [
                "serialNo",
                "analysisModelId",
                "question",
                "selfDefineTags",
                "prompt",
                "standardAnswerSql",
                "standardAnswer",
                "llmCode",
                "llmOutput",
                "executeResult",
                "errorMsg",
                "compareResult",
            ]

            # Load or create workbook and sheet
            if Path(output_file).exists():
                workbook = load_workbook(str(output_file))
                if "benchmark_compare_result" in workbook.sheetnames:
                    worksheet = workbook["benchmark_compare_result"]
                else:
                    worksheet = workbook.create_sheet("benchmark_compare_result")
                    # Write headers if new sheet
                    for col_idx, header in enumerate(headers, 1):
                        worksheet.cell(row=1, column=col_idx, value=header)
            else:
                workbook = Workbook()
                worksheet = workbook.active
                worksheet.title = "benchmark_compare_result"
                # Write headers
                for col_idx, header in enumerate(headers, 1):
                    worksheet.cell(row=1, column=col_idx, value=header)

            # Determine start row to append
            start_row = worksheet.max_row + 1 if worksheet.max_row else 2

            # Append rows
            for idx, cm in enumerate(confirm_models):
                row_data = [
                    cm.serialNo,
                    cm.analysisModelId,
                    cm.question,
                    cm.selfDefineTags,
                    cm.prompt,
                    cm.standardAnswerSql,
                    self._format_set_result(cm.strategyConfig.standard_result)
                    if cm.strategyConfig is not None
                    else "",
                    cm.llmCode,
                    cm.llmOutput,
                    json.dumps(cm.executeResult, ensure_ascii=False)
                    if cm.executeResult is not None
                    else "",
                    cm.errorMsg,
                    cm.compareResult.value if cm.compareResult else None,
                ]
                for col_idx, value in enumerate(row_data, 1):
                    worksheet.cell(row=start_row + idx, column=col_idx, value=value)

            # Autosize columns (simple strategy)
            for column in worksheet.columns:
                max_length = 0
                column_letter = column[0].column_letter
                for cell in column:
                    try:
                        if cell.value and len(str(cell.value)) > max_length:
                            max_length = len(str(cell.value))
                    except Exception:
                        pass
                adjusted_width = min(max(max_length + 2, 10), 80)
                worksheet.column_dimensions[column_letter].width = adjusted_width

            workbook.save(str(output_file))
            workbook.close()
            logger.info(
                f"[write_data_compare_result] compare written to Excel: {output_file}"
            )
        except Exception as e:
            logger.error(
                f"[write_data_compare_result] write excel error for path={path}: {e}"
            )

    def _get_value_by_source_type(
        self,
        field: str,
@@ -636,46 +467,6 @@ class ExcelFileParseService(FileParseService):
        else:
            return str(value) if value is not None else ""

    def _parse_multi_standard_result(
        self, std_result_raw: str
    ) -> Optional[List[Dict[str, List[str]]]]:
        """
        Parse multiple standard results from raw string data.

        Handles multiple results separated by newlines and parses each line as a dict.

        Args:
            std_result_raw (str): Raw standard result string with multiple lines

        Returns:
            Optional[List[Dict[str, List[str]]]]: List of parsed dictionaries,
                or None if parsing fails or no valid data
        """
        try:
            std_result_raw = std_result_raw.strip()
            if not std_result_raw:
                return None

            # Multiple results are separated by newlines
            result_lines = std_result_raw.split("\n")
            result_list = []

            for line in result_lines:
                line = line.strip()
                if line:
                    try:
                        result_list.append(json.loads(line))
                    except Exception as e:
                        logger.warning(
                            f"Failed to parse line as JSON: {line}, error: {e}"
                        )
                        continue

            return result_list if result_list else None
        except Exception as e:
            logger.error(f"parse multiple standard results error: {e}")
            return None

    def _format_set_result(
        self, sql_result: List[Dict[str, List[str]]]
    ) -> Optional[str]:
@@ -4,7 +4,6 @@ import logging
import os
from typing import Dict, List, Optional, Union

-from dbgpt.util.benchmarks import StorageUtil
from dbgpt_serve.evaluate.db.benchmark_db import BenchmarkResultDao
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
    BENCHMARK_DEFAULT_DB_SCHEMA,
@@ -12,6 +11,8 @@ from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
)

from .data_compare_service import DataCompareService
+from .ext.excel_file_parse import ExcelFileParseService
+from .ext.falcon_file_parse import FalconFileParseService
from .file_parse_service import FileParseService
from .models import (
    AnswerExecuteModel,
@@ -33,14 +34,28 @@ logger = logging.getLogger(__name__)

class UserInputExecuteService:
    def __init__(
-        self, file_service: FileParseService, compare_service: DataCompareService
+        self, compare_service: DataCompareService, file_type: FileParseTypeEnum = None
    ):
-        self.file_service = file_service
        self.compare_service = compare_service
+        self.file_service = self.file_service(file_type)

        self.dao = BenchmarkResultDao()
        # sql query timeout in seconds
        self.query_timeout = float(os.getenv("BENCHMARK_SQL_TIMEOUT", 360.0))

+    def file_service(self, file_type: FileParseTypeEnum) -> FileParseService:
+        """
+        Get file service instance based on file type.
+
+        Returns:
+            FileParseService: File service instance
+        """
+        if file_type == FileParseTypeEnum.GITHUB:
+            return FalconFileParseService()
+        elif file_type == FileParseTypeEnum.EXCEL:
+            return ExcelFileParseService()
+        raise NotImplementedError(f"fileParseType: {file_type} is not implemented yet")
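A hedged sketch of the new wiring, mirroring the BenchmarkService.__init__ change earlier in this diff:

    dcs = DataCompareService()
    svc = UserInputExecuteService(dcs, FileParseTypeEnum.GITHUB)
    # svc.file_service is now a FalconFileParseService instance;
    # FileParseTypeEnum.EXCEL would select ExcelFileParseService instead.

Note that `self.file_service = self.file_service(file_type)` shadows the `file_service` method with the returned instance, so the factory is effectively single-use per object.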
    def read_input_file(
        self, input_file_path: str
    ) -> Union[List[BaseInputModel], None]:
@@ -53,15 +68,10 @@ class UserInputExecuteService:
        Returns:
            List[BaseInputModel]: Input data list
        """
-        file_parse_type: FileParseTypeEnum = StorageUtil.get_file_parse_type(
+        input_sets: BenchmarkDataSets = self.file_service.parse_input_sets(
            input_file_path
        )
-        if file_parse_type == FileParseTypeEnum.EXCEL:
-            input_sets: BenchmarkDataSets = self.file_service.parse_input_sets(
-                input_file_path
-            )
-            return input_sets.data_list
-        return None
+        return input_sets.data_list
    def post_dispatch(
        self,
@@ -225,14 +235,13 @@ class UserInputExecuteService:
            )

        results = json.loads(summary_json) if summary_json else []
-        dao = BenchmarkResultDao()
        for item in results:
            llm_code = item.get("llmCode")
            right = int(item.get("right", 0))
            wrong = int(item.get("wrong", 0))
            failed = int(item.get("failed", 0))
            exception = int(item.get("exception", 0))
-            dao.upsert_summary(
+            self.dao.upsert_summary(
                round_id,
                location,
                llm_code,
@@ -4,15 +4,17 @@ import hashlib
import json
import logging
import os
import re
import shutil
import tempfile
import threading
import time
import uuid
import zipfile
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeoutError
from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Tuple, cast, Literal

import aiohttp
from sqlalchemy import text
@@ -24,6 +26,90 @@ from dbgpt_ext.datasource.rdbms.conn_sqlite import SQLiteConnector

logger = logging.getLogger(__name__)
# ---- Unified model result definitions for load_file_from_github ----
class FailureDetail(BaseModel):
    line_no: int
    error: str
    line: str


class Row(BaseModel):
    line_no: int
    data: Any


class FileLoadResult(BaseModel):
    type: Literal["jsonl", "json", "text"]
    file_path: str
    file_name: str
    encoding: Optional[str] = None
    rows: List[Row]
    count: int
    failed_count: int
    failures: List[FailureDetail] = []


class SqlFileItem(BaseModel):
    """Represents a single SQL file with its ID and content"""

    sql_id: str
    sql_content: str
    file_path: str
    file_name: str
    encoding: Optional[str] = None


class GoldenSqlListResult(BaseModel):
    """Result object for golden SQL list loading

    Provides efficient lookup by SQL ID with dict-like interface.
    """

    sql_items: Dict[str, SqlFileItem]
    total_count: int
    failed_count: int

    def get_by_id(self, sql_id: str) -> Optional[SqlFileItem]:
        """Get SQL item by ID

        Args:
            sql_id: The SQL file ID (filename prefix without extension)

        Returns:
            SqlFileItem if found, None otherwise
        """
        return self.sql_items.get(sql_id)

    def get_sql_content(self, sql_id: str) -> Optional[str]:
        """Get SQL content by ID

        Args:
            sql_id: The SQL file ID (filename prefix without extension)

        Returns:
            SQL content string if found, None otherwise
        """
        item = self.sql_items.get(sql_id)
        return item.sql_content if item else None

    def list_all_ids(self) -> List[str]:
        """Get list of all SQL IDs

        Returns:
            List of SQL IDs sorted alphabetically
        """
        return sorted(self.sql_items.keys())

    def __len__(self) -> int:
        """Return number of successfully loaded SQL files"""
        return len(self.sql_items)

    def __contains__(self, sql_id: str) -> bool:
        """Check if SQL ID exists"""
        return sql_id in self.sql_items

    def __iter__(self):
        """Iterate over SQL items"""
        return iter(self.sql_items.values())

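GoldenSqlListResult behaves like a read-only mapping; a small illustration with invented values:

    item = SqlFileItem(
        sql_id="q1", sql_content="SELECT 1;",
        file_path="golden/q1.sql", file_name="q1.sql",  # hypothetical paths
    )
    golden = GoldenSqlListResult(sql_items={"q1": item}, total_count=1, failed_count=0)
    assert "q1" in golden and len(golden) == 1
    assert golden.get_sql_content("q1") == "SELECT 1;"
    assert golden.list_all_ids() == ["q1"]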
BENCHMARK_DEFAULT_DB_SCHEMA = "ant_icube_dev."


@@ -36,12 +122,10 @@ class BenchmarkDataConfig(BaseModel):
    db_path: str = os.path.join(
        BENCHMARK_DATA_ROOT_PATH, f"{BENCHMARK_DEFAULT_DB_SCHEMA}db"
    )
-    table_mapping_file: str = os.path.join(
-        BENCHMARK_DATA_ROOT_PATH, "table_mapping.json"
-    )
+    table_mapping_file: Optional[str] = None
    cache_expiry_days: int = 1
    repo_url: str = "https://github.com/eosphoros-ai/Falcon"
-    data_dir: str = "data/source"
+    data_dirs: List[str] = ["dev_data/dev_databases", "test_data/dev_databases"]
class BenchmarkDataManager(BaseComponent):
@@ -56,7 +140,6 @@ class BenchmarkDataManager(BaseComponent):
        self._config = config or BenchmarkDataConfig()
        self._http_session: Optional[aiohttp.ClientSession] = None
        self._connector: Optional[SQLiteConnector] = None
-        self._table_mappings = self._load_mappings()
        self._lock = asyncio.Lock()
        self.temp_dir: Optional[str] = None
@@ -73,9 +156,7 @@ class BenchmarkDataManager(BaseComponent):

    async def async_before_stop(self):
        try:
-            logger.info("BenchmarkDataManager: closing resources before stop...")
            await self.close()
-            logger.info("BenchmarkDataManager: close done.")
        except Exception as e:
            logger.warning(f"BenchmarkDataManager: close failed: {e}")
@@ -120,7 +201,6 @@ class BenchmarkDataManager(BaseComponent):

    async def load_data(self):
        logger.info("BenchmarkDataManager: start load_data.")

        try:
            if not self._config.repo_url:
                logger.info("BenchmarkDataManager: repo_url not set, skip auto load.")
@@ -132,69 +212,16 @@ class BenchmarkDataManager(BaseComponent):

            logger.info(
                f"BenchmarkDataManager: auto loading repo {self._config.repo_url} "
-                f"dir={self._config.data_dir}"
+                f"dirs={self._config.data_dirs}"
            )
            await get_benchmark_manager(self.system_app).load_from_github(
-                repo_url=self._config.repo_url, data_dir=self._config.data_dir
+                repo_url=self._config.repo_url, data_dirs=self._config.data_dirs
            )
            self._startup_loaded = True
            logger.info("BenchmarkDataManager: auto load finished.")
        except Exception as e:
            logger.error(f"BenchmarkDataManager: auto load failed: {e}")

-    def _sanitize_column_name(self, name: str) -> str:
-        if name is None:
-            return ""
-        name = str(name).strip().strip('"').strip("'")
-        invalid_chars = [
-            "-", " ", ".", ",", ";", ":", "!", "?", "'", '"',
-            "(", ")", "[", "]", "{", "}", "\t", "\r", "\n", "\x00",
-        ]
-        while name and name[-1] in invalid_chars:
-            name = name[:-1]
-        for ch in invalid_chars:
-            if ch in name:
-                name = name.replace(ch, "_")
-        while "__" in name:
-            name = name.replace("__", "_")
-        if name and not (name[0].isalpha() or name[0] == "_"):
-            name = "_" + name
-        return name.lower()

-    def _sanitize_and_dedup_headers(self, headers: List[str]) -> List[str]:
-        sanitized: List[str] = []
-        used: set = set()
-        for idx, h in enumerate(headers):
-            name = self._sanitize_column_name(h)
-            if not name:
-                name = f"col_{idx}"
-            base = name
-            k = 2
-            while name in used or not name:
-                name = f"{base}_{k}"
-                k += 1
-            used.add(name)
-            sanitized.append(name)
-        return sanitized

    # ==========================================================
    # Generic query (blocking implementation, invoked from a thread pool,
    # with timeout and cancellation support)
@@ -321,7 +348,9 @@ class BenchmarkDataManager(BaseComponent):
        return [dict(zip(cols, row)) for row in rows]

    async def load_from_github(
-        self, repo_url: str, data_dir: str = "data/source"
+        self,
+        repo_url: str,
+        data_dirs: List[str] = ["dev_data/dev_databases", "test_data/dev_databases"],
    ) -> Dict:
        """Main method to load data from GitHub repository"""
        try:
@@ -330,14 +359,26 @@ class BenchmarkDataManager(BaseComponent):
            # 1. Download or use cached repository
            repo_dir = await self._download_repo_contents(repo_url)

-            # 2. Find all CSV files recursively
-            csv_files = self._discover_csv_files(repo_dir, data_dir)
-            if not csv_files:
-                raise ValueError("No CSV files found")
-            logger.info(f"Found {len(csv_files)} CSV files")
+            # 2. Find all SQLite files recursively in the specified data_dirs
+            all_sqlite_files = []
+            for d_dir in data_dirs:
+                try:
+                    files = self._discover_sqlite_files(repo_dir, d_dir)
+                    logger.info(f"Found {len(files)} SQLite files in {d_dir}")
+                    all_sqlite_files.extend(files)
+                except ValueError as ve:
+                    # If a directory is missing, log a warning but do not abort the whole flow
+                    logger.warning(f"Skip directory {d_dir}: {ve}")

-            # 3. Import to SQLite
-            result = await self._import_to_database(csv_files)
+            if not all_sqlite_files:
+                raise ValueError(
+                    f"No SQLite files found in any of the directories: {data_dirs}"
+                )
+
+            logger.info(f"Total SQLite files to merge: {len(all_sqlite_files)}")
+
+            # 3. Merge all SQLite files into the main database
+            result = await self._merge_sqlite_databases(all_sqlite_files)
            return result

        except Exception as e:
@@ -346,6 +387,280 @@ class BenchmarkDataManager(BaseComponent):
|
||||
finally:
|
||||
self._cleanup_temp_dir()
|
||||
|
||||
async def load_file_from_github(self, file_name: Optional[str] = None
|
||||
) -> Optional[FileLoadResult]:
|
||||
"""Download and read a specified file from a GitHub repository.
|
||||
|
||||
Supported file types: .json / .jsonl
|
||||
`file_name` can be a relative path within the repository or a plain filename (will be searched recursively).
|
||||
|
||||
Unified return structure (FileLoadResult):
|
||||
- type: "json" | "jsonl"
|
||||
- file_path, file_name, encoding
|
||||
- rows: List[{line_no:int, data:Any}] where data is parsed JSON object
|
||||
- count: total number of rows
|
||||
- failed_count: number of failed lines (non-zero for jsonl or malformed json)
|
||||
- failures: details for failed lines
|
||||
|
||||
For JSON files:
|
||||
- If the file contains a JSON array, each element becomes a Row
|
||||
- If the file contains a single JSON object, it becomes one Row
|
||||
- The structure is flexible and doesn't depend on specific keys
|
||||
"""
|
||||
        try:
            if not file_name or not str(file_name).strip():
                return None

            # Download or use cached repository
            repo_dir = await self._download_repo_contents(self._config.repo_url)

            # Allowed file extensions
            allowed_exts = {".jsonl", ".json"}

            # Pre-check the extension of `file_name` (if provided); otherwise filter by the allowed list later
            _, requested_ext = os.path.splitext(str(file_name).lower())
            if requested_ext and requested_ext not in allowed_exts:
                raise ValueError(f"Unsupported file type: {requested_ext}")

            # Handle both relative path and plain filename cases
            normalized = str(file_name).strip().lstrip("/").replace("\\", os.sep)
            candidate_paths: List[str] = []

            # Prefer direct path resolution using the relative path
            direct_path = os.path.join(repo_dir, normalized)
            if os.path.isfile(direct_path):
                ext = os.path.splitext(direct_path.lower())[1]
                if not requested_ext:
                    if ext in allowed_exts:
                        candidate_paths.append(direct_path)
                elif ext == requested_ext:
                    candidate_paths.append(direct_path)

            # If not found, recursively search by filename match
            if not candidate_paths:
                target_name = os.path.basename(normalized)
                for root, _, files in os.walk(repo_dir):
                    for f in files:
                        if f == target_name:
                            full = os.path.join(root, f)
                            ext = os.path.splitext(f.lower())[1]
                            if not requested_ext:
                                if ext in allowed_exts:
                                    candidate_paths.append(full)
                            elif ext == requested_ext:
                                candidate_paths.append(full)

            if not candidate_paths:
                raise FileNotFoundError(f"File not found: {file_name}")

            # Choose a stable candidate (sorted by path length, then lexicographically)
            chosen = sorted(candidate_paths, key=lambda p: (len(p), p))[0]
            chosen_ext = os.path.splitext(chosen.lower())[1]

            # Build the repository-relative path (avoid returning a temp local path)
            rel_path = os.path.relpath(chosen, repo_dir)
            rel_path_posix = rel_path.replace(os.sep, "/")

            # Try multiple encodings
            encodings = ["utf-8", "iso-8859-1"]

            # Handle .json files (array or single object)
            if chosen_ext == ".json":
                return await self._parse_json_file(
                    chosen, rel_path_posix, encodings
                )

            # Handle .jsonl files (line-delimited JSON)
            elif chosen_ext == ".jsonl":
                return await self._parse_jsonl_file(
                    chosen, rel_path_posix, encodings
                )

            else:
                raise ValueError(f"Unsupported file extension: {chosen_ext}")

        except Exception as e:
            logger.error(f"Falcon repository import failed: {str(e)}")
            raise RuntimeError(f"Falcon repository file data loading failed: {e}") from e
        finally:
            self._cleanup_temp_dir()

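For orientation, a minimal usage sketch of the method above (hedged: `manager` is a hypothetical, already-initialized BenchmarkDataManager, and "questions.jsonl" is a hypothetical file; field names follow the docstring):

    # Hypothetical caller; not part of this commit.
    async def show_first_rows(manager) -> None:
        result = await manager.load_file_from_github("questions.jsonl")
        if result is None:
            return
        print(
            f"Loaded {result.count} rows ({result.failed_count} failed) "
            f"from {result.file_path} using {result.encoding}"
        )
        for row in result.rows[:3]:
            print(row.line_no, row.data)
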
    async def _parse_json_file(
        self, file_path: str, rel_path_posix: str, encodings: List[str]
    ) -> FileLoadResult:
        """Parse a JSON file (array or single object).

        Args:
            file_path: Absolute path to the JSON file
            rel_path_posix: Repository-relative path in POSIX format
            encodings: List of encodings to try

        Returns:
            FileLoadResult with parsed data
        """
        rows: List[Row] = []
        failures: List[FailureDetail] = []
        used_encoding: Optional[str] = None

        # Try reading with different encodings
        for enc in encodings:
            try:
                with open(file_path, "r", encoding=enc) as f:
                    content = f.read()

                try:
                    data = json.loads(content)

                    # Handle a JSON array
                    if isinstance(data, list):
                        for idx, item in enumerate(data, start=1):
                            rows.append(Row(line_no=idx, data=item))
                    # Handle a single JSON object
                    elif isinstance(data, dict):
                        rows.append(Row(line_no=1, data=data))
                    else:
                        # Handle primitive types (string, number, etc.)
                        rows.append(Row(line_no=1, data=data))

                    used_encoding = enc
                    break

                except json.JSONDecodeError as e:
                    failures.append(
                        FailureDetail(
                            line_no=1,
                            error=f"JSON decode error: {str(e)}",
                            line=content[:200],
                        )
                    )
                    used_encoding = enc
                    break

            except UnicodeDecodeError:
                continue
            except Exception as e:
                logger.warning(f"Read json with encoding {enc} failed: {e}")
                continue

        # Fallback: read as bytes and decode as ASCII, ignoring errors
        if used_encoding is None:
            try:
                with open(file_path, "rb") as f:
                    content = f.read().decode("ascii", errors="ignore")

                try:
                    data = json.loads(content)

                    if isinstance(data, list):
                        for idx, item in enumerate(data, start=1):
                            rows.append(Row(line_no=idx, data=item))
                    elif isinstance(data, dict):
                        rows.append(Row(line_no=1, data=data))
                    else:
                        rows.append(Row(line_no=1, data=data))

                except json.JSONDecodeError as e:
                    failures.append(
                        FailureDetail(
                            line_no=1,
                            error=f"JSON decode error: {str(e)}",
                            line=content[:200],
                        )
                    )

                used_encoding = "ascii-ignore"
            except Exception as e:
                raise ValueError(f"Failed to read json file: {e}") from e

        return FileLoadResult(
            type="json",
            file_path=rel_path_posix,
            file_name=os.path.basename(file_path),
            encoding=used_encoding,
            rows=rows,
            count=len(rows) + len(failures),
            failed_count=len(failures),
            failures=failures,
        )

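To make the return shape concrete, a hedged illustration (hypothetical file content; Row and FileLoadResult as declared in the models module):

    # For a .json file containing [{"q": "a"}, {"q": "b"}]:
    # result.type         -> "json"
    # result.rows         -> [Row(line_no=1, data={"q": "a"}),
    #                         Row(line_no=2, data={"q": "b"})]
    # result.count        -> 2
    # result.failed_count -> 0
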
    async def _parse_jsonl_file(
        self, file_path: str, rel_path_posix: str, encodings: List[str]
    ) -> FileLoadResult:
        """Parse a JSONL file (line-delimited JSON).

        Args:
            file_path: Absolute path to the JSONL file
            rel_path_posix: Repository-relative path in POSIX format
            encodings: List of encodings to try

        Returns:
            FileLoadResult with parsed data
        """
        rows: List[Row] = []
        failures: List[FailureDetail] = []
        used_encoding: Optional[str] = None

        # Prefer reading in text mode with multiple encodings
        for enc in encodings:
            try:
                with open(file_path, "r", encoding=enc) as f:
                    for idx, line in enumerate(f, start=1):
                        s = line.strip()
                        if not s:
                            continue
                        try:
                            obj = json.loads(s)
                            rows.append(Row(line_no=idx, data=obj))
                        except Exception as e:
                            failures.append(
                                FailureDetail(
                                    line_no=idx,
                                    error=str(e),
                                    line=s[:200],
                                )
                            )
                used_encoding = enc
                break
            except UnicodeDecodeError:
                continue
            except Exception as e:
                logger.warning(f"Read jsonl with encoding {enc} failed: {e}")
                continue

        # Fallback: read as bytes and decode as ASCII, ignoring errors
        if used_encoding is None:
            try:
                with open(file_path, "rb") as f:
                    for idx, raw_line in enumerate(f, start=1):
                        s = raw_line.decode("ascii", errors="ignore").strip()
                        if not s:
                            continue
                        try:
                            obj = json.loads(s)
                            rows.append(Row(line_no=idx, data=obj))
                        except Exception as e:
                            failures.append(
                                FailureDetail(
                                    line_no=idx,
                                    error=str(e),
                                    line=s[:200],
                                )
                            )
                used_encoding = "ascii-ignore"
            except Exception as e:
                raise ValueError(f"Failed to read jsonl file: {e}") from e

        return FileLoadResult(
            type="jsonl",
            file_path=rel_path_posix,
            file_name=os.path.basename(file_path),
            encoding=used_encoding,
            rows=rows,
            count=len(rows) + len(failures),
            failed_count=len(failures),
            failures=failures,
        )

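A hedged illustration of the failure accounting (hypothetical file content): blank lines are skipped, so line_no values may be non-contiguous, and count is parsed rows plus failures:

    # For a .jsonl file whose second line is malformed:
    #   {"q": 1}
    #   {bad json}
    #   {"q": 3}
    # result.rows     -> [Row(line_no=1, ...), Row(line_no=3, ...)]
    # result.failures -> [FailureDetail(line_no=2, error="...", line="{bad json}")]
    # result.count    -> 3, result.failed_count -> 1
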
    async def get_table_info(self) -> Dict:
        """Get metadata about all tables"""
        await self.init_connector()
@@ -445,7 +760,7 @@ class BenchmarkDataManager(BaseComponent):
        return (mapped_name or "").lower()

    async def _download_repo_contents(self, repo_url: str) -> str:
        """Download repository with caching"""
        """Download repository with caching, supporting branch URLs"""
        cache_path = self._get_cache_path(repo_url)

        # Use cache if valid
@@ -455,21 +770,45 @@ class BenchmarkDataManager(BaseComponent):

        # Download fresh copy
        self.temp_dir = tempfile.mkdtemp()
        zip_url = (
            repo_url.replace("github.com", "api.github.com/repos") + "/zipball/main"
        )

        # Simple parsing for github.com URLs
        github_pattern = r"github\.com/([^/]+)/([^/]+)(?:/tree/(.+))?"
        match = re.search(github_pattern, repo_url)

        if match:
            owner, repo, branch = match.groups()
            branch = branch or "main"  # Default to main if no tree/branch specified
            zip_url = f"https://api.github.com/repos/{owner}/{repo}/zipball/{branch}"
        else:
            # Fallback for generic structure or direct zip links
            if repo_url.endswith(".zip"):
                zip_url = repo_url
            else:
                # Default fallback behavior from original code
                zip_url = (
                    repo_url.replace("github.com", "api.github.com/repos")
                    + "/zipball/main"
                )

        logger.info(f"Downloading from GitHub repo: {zip_url}")

        try:
            if self._http_session is None:
                self._http_session = aiohttp.ClientSession()
            async with self._http_session.get(zip_url) as response:
                response.raise_for_status()

            headers = {"Accept": "application/vnd.github.v3+json"}
            async with self._http_session.get(zip_url, headers=headers) as response:
                if response.status != 200:
                    text_resp = await response.text()
                    raise RuntimeError(
                        f"GitHub API Error {response.status}: {text_resp}"
                    )

                zip_path = os.path.join(self.temp_dir, "repo.zip")

                with open(zip_path, "wb") as f:
                    while True:
                        chunk = await response.content.read(1024)
                        chunk = await response.content.read(1024 * 1024)  # 1MB chunks
                        if not chunk:
                            break
                        f.write(chunk)
@@ -479,7 +818,6 @@ class BenchmarkDataManager(BaseComponent):
                logger.info(f"Saved repository to cache: {cache_path}")

                return self._extract_zip(zip_path)

        except Exception as e:
            self._cleanup_temp_dir()
            raise RuntimeError(f"Failed to download repository: {str(e)}") from e
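Given the regex above, the branch-aware URL resolution behaves as follows (hypothetical URLs, for illustration only):

    # "https://github.com/owner/repo"
    #     -> "https://api.github.com/repos/owner/repo/zipball/main"
    # "https://github.com/owner/repo/tree/dev"
    #     -> "https://api.github.com/repos/owner/repo/zipball/dev"
    # "https://example.com/archive.zip"
    #     -> used as-is (direct zip link)
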
@@ -515,252 +853,112 @@ class BenchmarkDataManager(BaseComponent):
            raise ValueError("No valid directory found after extraction")
        return os.path.join(self.temp_dir, extracted_dirs[0])

    def _discover_csv_files(self, base_dir: str, search_dir: str) -> List[Dict]:
        """Find all CSV files recursively"""
    def _discover_sqlite_files(self, base_dir: str, search_dir: str) -> List[str]:
        """Find all SQLite files recursively in the search directory"""
        full_search_dir = os.path.join(base_dir, search_dir) if search_dir else base_dir
        if not os.path.exists(full_search_dir):
            raise ValueError(f"Directory not found: {full_search_dir}")

        csv_files = []
        sqlite_files = []
        for root, _, files in os.walk(full_search_dir):
            for file in files:
                if file.lower().endswith(".csv"):
                    rel_path = os.path.relpath(root, start=base_dir)
                    csv_files.append(
                        {
                            "full_path": os.path.join(root, file),
                            "rel_path": rel_path,
                            "file_name": file,
                        }
                    )
        return csv_files
                if file.lower().endswith(".sqlite"):
                    full_path = os.path.join(root, file)
                    sqlite_files.append(full_path)
        return sqlite_files

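A hedged usage sketch of the new discovery helper (paths are hypothetical):

    # files = self._discover_sqlite_files("/tmp/extracted_repo", "databases")
    # -> ["/tmp/extracted_repo/databases/dev.sqlite",
    #     "/tmp/extracted_repo/databases/sub/test.sqlite"]
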
    async def _import_to_database(self, csv_files: List[Dict]) -> Dict:
        """Import CSV data to SQLite"""
    async def _merge_sqlite_databases(self, sqlite_files: List[str]) -> Dict:
        """Merge multiple SQLite files into the main database"""
        await self.init_connector()
        assert self._connector is not None
        results = {
            "total_files": len(csv_files),
            "successful": 0,
            "failed": 0,
            "tables_created": [],
        }

        def _process_one_file(file_info: Dict) -> Tuple[bool, Optional[str]]:
            table_name = ""
            try:
                path_parts = [p for p in file_info["rel_path"].split(os.sep) if p]
                table_name = "_".join(path_parts + [Path(file_info["file_name"]).stem])
                table_name = self._sanitize_table_name(table_name)

                with self._connector.session_scope() as session:
                    session.execute(text(f'DROP TABLE IF EXISTS "{table_name}"'))
                    session.commit()
                encodings = ["utf-8-sig", "utf-8", "latin-1", "iso-8859-1", "cp1252"]

                for encoding in encodings:
                    try:
                        with open(file_info["full_path"], "r", encoding=encoding) as f:
                            content = f.read()

                        if not content.strip():
                            raise ValueError("File is empty")

                        content = content.replace("\r\n", "\n").replace("\r", "\n")
                        lines = [line for line in content.split("\n") if line.strip()]
                        if not lines:
                            raise ValueError("No data after normalization")

                        header_line = lines[0]
                        data_line = lines[1] if len(lines) > 1 else ""

                        try:
                            sample_for_sniff = "\n".join(lines[:10])
                            sniffer = csv.Sniffer()
                            try:
                                dialect = sniffer.sniff(sample_for_sniff)
                            except Exception:
                                # Fallback: choose delimiter by counting common
                                # separators in header/data line
                                delims = [",", "\t", ";", "|"]
                                counts = {
                                    d: (header_line.count(d) if header_line else 0)
                                    + (data_line.count(d) if data_line else 0)
                                    for d in delims
                                }
                                best = (
                                    max(counts, key=counts.get)
                                    if any(counts.values())
                                    else ","
                                )

                                class _DefaultDialect(csv.Dialect):
                                    delimiter = best
                                    quotechar = '"'
                                    doublequote = True
                                    skipinitialspace = False
                                    lineterminator = "\n"
                                    quoting = csv.QUOTE_MINIMAL

                                dialect = _DefaultDialect()

                            try:
                                has_header = sniffer.has_header("\n".join(lines[:50]))
                            except Exception:
                                has_header = True

                            header_row = (
                                list(csv.reader([header_line], dialect))[0]
                                if header_line
                                else []
                            )
                            first_data_row = (
                                list(csv.reader([data_line], dialect))[0]
                                if data_line
                                else []
                            )

                            # Heuristic: if has_header is False but header_row looks
                            # like names (mostly alphabetic), treat it as a header
                            if not has_header:

                                def _looks_like_header(tokens: List[str]) -> bool:
                                    if not tokens:
                                        return False
                                    # Non-empty, few duplicates, high ratio of alphabetic tokens
                                    cleaned = [
                                        str(t).strip() for t in tokens if str(t).strip()
                                    ]
                                    if not cleaned:
                                        return False
                                    # Allow a few digits, but most tokens should start with a letter
                                    alpha_starts = sum(
                                        1
                                        for t in cleaned
                                        if t and (t[0].isalpha() or t[0] == "_")
                                    )
                                    return alpha_starts >= max(
                                        1, int(0.6 * len(cleaned))
                                    )

                                if _looks_like_header(header_row):
                                    has_header = True

                            if not has_header:
                                num_cols_guess = len(header_row)
                                headers = [f"col_{i}" for i in range(num_cols_guess)]
                                first_data_row = header_row
                            else:
                                headers = header_row

                            num_cols = (
                                len(first_data_row) if first_data_row else len(headers)
                            )

                            # No header at all: synthesize column names
                            if not headers or all(
                                (not str(h).strip()) for h in headers
                            ):
                                headers = [f"col_{i}" for i in range(num_cols or 1)]

                            headers = self._sanitize_and_dedup_headers(headers)

                            if num_cols <= 0:
                                num_cols = len(headers)
                            headers = headers[:num_cols]
                            if not headers or any(
                                h is None or h == "" for h in headers
                            ):
                                raise csv.Error("Invalid headers after sanitization")

|
||||
CREATE TABLE IF NOT EXISTS "{table_name}" (
|
||||
{", ".join([f'"{h}" TEXT' for h in headers])}
|
||||
)
|
||||
'''
|
||||
insert_sql = f'''
|
||||
INSERT INTO "{table_name}" ({
|
||||
", ".join([f'"{h}"' for h in headers])
|
||||
})
|
||||
VALUES ({
|
||||
", ".join([":" + f"p{i}" for i in range(len(headers))])
|
||||
})
|
||||
'''
|
||||
|
||||
with self._connector.session_scope() as session:
|
||||
logger.debug(
|
||||
f"Table: {table_name}, headers(final): {headers}"
|
||||
)
|
||||
session.execute(text(create_sql))
|
||||
|
||||
reader = csv.reader(lines, dialect)
|
||||
if has_header:
|
||||
next(reader, None)
|
||||
|
||||
batch_params: List[Dict[str, Any]] = []
|
||||
for row in reader:
|
||||
if not row:
|
||||
continue
|
||||
if len(row) != len(headers):
|
||||
if len(row) < len(headers):
|
||||
row += [None] * (len(headers) - len(row))
|
||||
else:
|
||||
row = row[: len(headers)]
|
||||
params = {
|
||||
f"p{i}": (row[i] if i < len(row) else None)
|
||||
for i in range(len(headers))
|
||||
}
|
||||
batch_params.append(params)
|
||||
if len(batch_params) >= 1000:
|
||||
session.execute(text(insert_sql), batch_params)
|
||||
batch_params = []
|
||||
if batch_params:
|
||||
session.execute(text(insert_sql), batch_params)
|
||||
session.commit()
|
||||
|
||||
return True, table_name
|
||||
|
||||
except csv.Error:
|
||||
self._import_with_simple_split_blocking(table_name, content)
|
||||
return True, table_name
|
||||
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"Error with encoding {encoding}: {str(e)}")
|
||||
continue
|
||||
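To make the parameterized batch insert above concrete, a hedged illustration (table name and headers are hypothetical):

    # For table_name="questions" and headers=["id", "text"], the SQL built above is,
    # whitespace aside:
    #   CREATE TABLE IF NOT EXISTS "questions" ("id" TEXT, "text" TEXT)
    #   INSERT INTO "questions" ("id", "text") VALUES (:p0, :p1)
    # with rows bound in batches of up to 1000 dicts like {"p0": "1", "p1": "..."}.
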
        def _worker():
            results = {
                "total_files": len(sqlite_files),
                "successful": 0,
                "failed": 0,
                "tables_merged": [],
            }

            with self._connector.session_scope() as session:
                # Get the underlying sqlite3 connection object
                connection_proxy = session.connection()
                # Handle the different ways SQLAlchemy versions expose the raw connection
                try:
                    with open(file_info["full_path"], "rb") as f:
                        content = f.read().decode("ascii", errors="ignore")
                    if content.strip():
                        self._import_with_simple_split_blocking(table_name, content)
                        return True, table_name
                    else:
                        raise ValueError("File is empty or unreadable")
                except Exception as e:
                    return (
                        False,
                        f"Failed to process {file_info['file_name']}: {str(e)}",
                    )
                    # SQLAlchemy 1.4+ / 2.0
                    raw_conn = connection_proxy.connection.dbapi_connection
                except AttributeError:
                    try:
                        # Older versions or certain drivers
                        raw_conn = connection_proxy.connection
                    except AttributeError:
                        # Last resort
                        raw_conn = session.get_bind().raw_connection()

            except Exception as e:
                return (
                    False,
                    f"Failed to process {file_info.get('full_path', '')}: {str(e)}",
                )
                # Ensure raw_conn is a sqlite3 connection object
                if not raw_conn:
                    raise RuntimeError("Failed to get raw sqlite3 connection")

        for file_info in csv_files:
            ok, info = await self._run_in_thread(_process_one_file, file_info)
            if ok:
                results["successful"] += 1
                if info:
                    results["tables_created"].append(info)
            else:
                results["failed"] += 1
                logger.error(info)
                cursor = raw_conn.cursor()

            for db_path in sqlite_files:
                src_alias = f"src_db_{uuid.uuid4().hex[:8]}"
                try:
                    try:
                        cursor.execute("PRAGMA database_list")
                        attached_dbs = cursor.fetchall()
                        for _, name, _ in attached_dbs:
                            if name not in ("main", "temp"):
                                cursor.execute(f"DETACH DATABASE {name}")
                    except Exception as cleanup_err:
                        logger.warning(f"Cleanup warning: {cleanup_err}")

                    cursor.execute(f"ATTACH DATABASE ? AS {src_alias}", (db_path,))

                    cursor.execute(
                        f"SELECT name, sql FROM {src_alias}.sqlite_master "
                        f"WHERE type='table' AND name NOT LIKE 'sqlite_%'"
                    )
                    tables = cursor.fetchall()

                    for table_name, create_sql in tables:
                        cursor.execute(
                            "SELECT name FROM sqlite_master "
                            "WHERE type='table' "
                            "AND name=?",
                            (table_name,),
                        )
                        if not cursor.fetchone():
                            cursor.execute(create_sql)
                            cursor.execute(
                                f'INSERT INTO main."{table_name}" '
                                f'SELECT * FROM {src_alias}."{table_name}"'
                            )
                            results["tables_merged"].append(table_name)
                        else:
                            logger.warning(
                                f"Table '{table_name}' exists. Skipping."
                            )

                    raw_conn.commit()
                    results["successful"] += 1

                except Exception as e:
                    logger.error(f"Failed to merge {db_path}: {e}")
                    results["failed"] += 1
                    try:
                        raw_conn.rollback()
                    except Exception:
                        pass
                finally:
                    try:
                        cursor.execute(f"DETACH DATABASE {src_alias}")
                    except Exception:
                        pass

            return results

        return await self._run_in_thread(_worker)

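For readers who want the merge logic in isolation, a self-contained sketch of the ATTACH-based copy performed by _worker, reduced to plain sqlite3 (paths and the function name are hypothetical; this is a sketch, not the code shipped in this commit):

import sqlite3

def merge_one(main_path: str, src_path: str) -> None:
    # Attach the source database and copy over any table that does not
    # already exist in the main database, mirroring _worker above.
    conn = sqlite3.connect(main_path)
    cur = conn.cursor()
    cur.execute("ATTACH DATABASE ? AS src", (src_path,))
    cur.execute(
        "SELECT name, sql FROM src.sqlite_master "
        "WHERE type='table' AND name NOT LIKE 'sqlite_%'"
    )
    for name, create_sql in cur.fetchall():
        cur.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
            (name,),
        )
        if cur.fetchone() is None:
            cur.execute(create_sql)  # recreate the schema in main
            cur.execute(f'INSERT INTO main."{name}" SELECT * FROM src."{name}"')
    conn.commit()
    cur.execute("DETACH DATABASE src")
    conn.close()
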
    def _import_with_simple_split_blocking(self, table_name: str, content: str):
        """Fallback method for malformed CSV files (blocking, executed via SQLAlchemy)"""
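The body of this fallback is truncated by the diff. Purely as a hedged sketch of what a "simple split" import typically does (hypothetical helper; not the committed implementation):

def _split_rows(content: str, delimiter: str = ","):
    # Naive line/delimiter split for files the csv module rejects.
    for line in content.replace("\r\n", "\n").split("\n"):
        if line.strip():
            yield [cell.strip() for cell in line.split(delimiter)]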