opt: code format

yaoyifan-yyf
2025-09-25 13:37:57 +08:00
parent 9e50bff12c
commit bc14084826
2 changed files with 143 additions and 99 deletions

View File

@@ -36,7 +36,9 @@ from dbgpt_app.base import (
from dbgpt_app.component_configs import initialize_components
from dbgpt_app.config import ApplicationConfig, ServiceWebParameters, SystemParameters
from dbgpt_serve.core import add_exception_handler
-from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import BenchmarkDataManager, get_benchmark_manager
+from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
+get_benchmark_manager,
+)
logger = logging.getLogger(__name__)
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -153,8 +155,6 @@ def initialize_app(param: ApplicationConfig, args: List[str] = None):
else:
loop.run_until_complete(load_benchmark_data())
binding_port = web_config.port
binding_host = web_config.host
if not web_config.light:
@@ -341,8 +341,7 @@ async def load_benchmark_data():
async with manager:
logger.info("Fetching data from GitHub repository...")
result = await manager.load_from_github(
-repo_url="https://github.com/inclusionAI/Falcon",
-data_dir="data/source"
+repo_url="https://github.com/inclusionAI/Falcon", data_dir="data/source"
)
# Log detailed results
@@ -351,22 +350,20 @@ async def load_benchmark_data():
logger.info(f"Successfully imported: {result['successful']}")
logger.info(f"Failed imports: {result['failed']}")
-if result['failed'] > 0:
+if result["failed"] > 0:
logger.warning(f"Encountered {result['failed']} failures during import")
# Verify the loaded data
table_info = await manager.get_table_info()
logger.info(f"Loaded {len(table_info)} tables into database")
-return {
-'import_result': result,
-'table_info': table_info
-}
+return {"import_result": result, "table_info": table_info}
except Exception as e:
logger.error("Failed to load benchmark data", exc_info=True)
raise RuntimeError(f"Benchmark data loading failed: {str(e)}") from e
if __name__ == "__main__":
# Parse command line arguments
_args = parse_args()
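For orientation, the startup hook above simply drives the benchmark data manager defined in the second file. A minimal standalone sketch of the same flow, assuming it is run outside the server with asyncio.run rather than the server's own event loop:

import asyncio

from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
    get_benchmark_manager,
)


async def main():
    # get_benchmark_manager() falls back to a fresh SystemApp when none is registered
    manager = get_benchmark_manager()
    async with manager:
        result = await manager.load_from_github(
            repo_url="https://github.com/inclusionAI/Falcon", data_dir="data/source"
        )
        table_info = await manager.get_table_info()
        print(result["successful"], result["failed"], len(table_info))


asyncio.run(main())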

View File

@@ -1,20 +1,21 @@
-import os
-import csv
-import sqlite3
-import aiohttp
import asyncio
-import zipfile
-import tempfile
-import shutil
-import time
+import csv
import hashlib
-from pathlib import Path
-from typing import List, Dict, Optional, Type
-import logging
import json
+import logging
+import os
+import shutil
+import sqlite3
+import tempfile
+import time
+import zipfile
+from pathlib import Path
+from typing import Dict, List, Optional
+import aiohttp
-from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt._private.pydantic import BaseModel, ConfigDict
+from dbgpt.component import BaseComponent, ComponentType, SystemApp
logger = logging.getLogger(__name__)
@@ -35,7 +36,9 @@ class BenchmarkDataManager(BaseComponent):
name = ComponentType.BENCHMARK_DATA_MANAGER
-def __init__(self, system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None):
+def __init__(
+self, system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None
+):
super().__init__(system_app)
self._config = config or BenchmarkDataConfig()
self._http_session = None
@@ -83,7 +86,9 @@ class BenchmarkDataManager(BaseComponent):
columns = [col[0] for col in cursor.description]
return [dict(zip(columns, row)) for row in cursor.fetchall()]
-async def load_from_github(self, repo_url: str, data_dir: str = "data/source") -> Dict:
+async def load_from_github(
+self, repo_url: str, data_dir: str = "data/source"
+) -> Dict:
"""Main method to load data from GitHub repository"""
try:
# 1. Download or use cached repository
@@ -97,7 +102,6 @@ class BenchmarkDataManager(BaseComponent):
# 3. Import to SQLite
result = await self._import_to_database(csv_files)
logger.info(f"Import completed: {result['successful']} succeeded, {result['failed']} failed")
return result
except Exception as e:
@@ -123,8 +127,8 @@ class BenchmarkDataManager(BaseComponent):
columns = cursor.fetchall()
result[table_name] = {
-'row_count': row_count,
-'columns': [{'name': col[1], 'type': col[2]} for col in columns]
+"row_count": row_count,
+"columns": [{"name": col[1], "type": col[2]} for col in columns],
}
return result
@@ -144,15 +148,19 @@ class BenchmarkDataManager(BaseComponent):
def _load_mappings(self) -> Dict[str, str]:
"""Load table name mappings from config file"""
-if not self._config.table_mapping_file or not os.path.exists(self._config.table_mapping_file):
-logger.warning(f"Table mapping file not found: {self._config.table_mapping_file}")
+if not self._config.table_mapping_file or not os.path.exists(
+self._config.table_mapping_file
+):
+logger.warning(
+f"Table mapping file not found: {self._config.table_mapping_file}"
+)
return {}
try:
-with open(self._config.table_mapping_file, 'r', encoding='utf-8') as f:
+with open(self._config.table_mapping_file, "r", encoding="utf-8") as f:
mapping = json.load(f)
return {
-key: value.split('.')[-1] if '.' in value else value
+key: value.split(".")[-1] if "." in value else value
for key, value in mapping.items()
}
except Exception as e:
@@ -164,11 +172,28 @@ class BenchmarkDataManager(BaseComponent):
mapped_name = self._table_mappings.get(name.lower(), name)
# Clean special characters
-invalid_chars = ['-', ' ', '.', ',', ';', ':', '!', '?', "'", '"', '(', ')', '[', ']', '{', '}']
+invalid_chars = [
+"-",
+" ",
+".",
+",",
+";",
+":",
+"!",
+"?",
+"'",
+'"',
+"(",
+")",
+"[",
+"]",
+"{",
+"}",
+]
for char in invalid_chars:
-mapped_name = mapped_name.replace(char, '_')
-while '__' in mapped_name:
-mapped_name = mapped_name.replace('__', '_')
+mapped_name = mapped_name.replace(char, "_")
+while "__" in mapped_name:
+mapped_name = mapped_name.replace("__", "_")
return mapped_name.lower()
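In short, _sanitize_table_name applies the optional mapping, replaces the punctuation characters listed above with underscores, collapses runs of underscores, and lower-cases the result. A rough standalone illustration of those cleanup rules (the sample name is made up):

# same character set as the invalid_chars list above
name = "Bird-Bench (dev)"
for char in "- .,;:!?'\"()[]{}":
    name = name.replace(char, "_")
while "__" in name:
    name = name.replace("__", "_")
print(name.lower())  # -> bird_bench_dev_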
@@ -183,15 +208,17 @@ class BenchmarkDataManager(BaseComponent):
# Download fresh copy
self.temp_dir = tempfile.mkdtemp()
-zip_url = repo_url.replace('github.com', 'api.github.com/repos') + '/zipball/main'
+zip_url = (
+repo_url.replace("github.com", "api.github.com/repos") + "/zipball/main"
+)
logger.info(f"Downloading from GitHub repo: {zip_url}")
try:
async with self._http_session.get(zip_url) as response:
response.raise_for_status()
-zip_path = os.path.join(self.temp_dir, 'repo.zip')
+zip_path = os.path.join(self.temp_dir, "repo.zip")
-with open(zip_path, 'wb') as f:
+with open(zip_path, "wb") as f:
while True:
chunk = await response.content.read(1024)
if not chunk:
@@ -210,7 +237,7 @@ class BenchmarkDataManager(BaseComponent):
def _get_cache_path(self, repo_url: str) -> str:
"""Get path to cached zip file"""
-cache_key = hashlib.md5(repo_url.encode('utf-8')).hexdigest()
+cache_key = hashlib.md5(repo_url.encode("utf-8")).hexdigest()
return os.path.join(self._config.cache_dir, f"{cache_key}.zip")
def _is_cache_valid(self, cache_path: str) -> bool:
@@ -227,11 +254,14 @@ class BenchmarkDataManager(BaseComponent):
def _extract_zip(self, zip_path: str) -> str:
"""Extract zip to temp directory"""
-with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(self.temp_dir)
-extracted_dirs = [d for d in os.listdir(self.temp_dir)
-if os.path.isdir(os.path.join(self.temp_dir, d))]
+extracted_dirs = [
+d
+for d in os.listdir(self.temp_dir)
+if os.path.isdir(os.path.join(self.temp_dir, d))
+]
if not extracted_dirs:
raise ValueError("No valid directory found after extraction")
return os.path.join(self.temp_dir, extracted_dirs[0])
@@ -245,13 +275,15 @@ class BenchmarkDataManager(BaseComponent):
csv_files = []
for root, _, files in os.walk(full_search_dir):
for file in files:
-if file.lower().endswith('.csv'):
+if file.lower().endswith(".csv"):
rel_path = os.path.relpath(root, start=base_dir)
-csv_files.append({
-'full_path': os.path.join(root, file),
-'rel_path': rel_path,
-'file_name': file
-})
+csv_files.append(
+{
+"full_path": os.path.join(root, file),
+"rel_path": rel_path,
+"file_name": file,
+}
+)
return csv_files
async def _import_to_database(self, csv_files: List[Dict]) -> Dict:
@@ -259,24 +291,24 @@ class BenchmarkDataManager(BaseComponent):
conn = await self.get_connection()
cursor = conn.cursor()
results = {
-'total_files': len(csv_files),
-'successful': 0,
-'failed': 0,
-'tables_created': []
+"total_files": len(csv_files),
+"successful": 0,
+"failed": 0,
+"tables_created": [],
}
for file_info in csv_files:
try:
-path_parts = [p for p in file_info['rel_path'].split(os.sep) if p]
-table_name = '_'.join(path_parts + [Path(file_info['file_name']).stem])
+path_parts = [p for p in file_info["rel_path"].split(os.sep) if p]
+table_name = "_".join(path_parts + [Path(file_info["file_name"]).stem])
table_name = self._sanitize_table_name(table_name)
# Try multiple encodings
-encodings = ['utf-8-sig', 'utf-8', 'latin-1', 'iso-8859-1', 'cp1252']
+encodings = ["utf-8-sig", "utf-8", "latin-1", "iso-8859-1", "cp1252"]
for encoding in encodings:
try:
-with open(file_info['full_path'], 'r', encoding=encoding) as f:
+with open(file_info["full_path"], "r", encoding=encoding) as f:
content = f.read()
# Handle empty files
@@ -284,12 +316,13 @@ class BenchmarkDataManager(BaseComponent):
raise ValueError("File is empty")
# Replace problematic line breaks if needed
-content = content.replace('\r\n', '\n').replace('\r', '\n')
+content = content.replace("\r\n", "\n").replace("\r", "\n")
# Split into lines
-lines = [line for line in content.split('\n') if line.strip()]
+lines = [
+line for line in content.split("\n") if line.strip()
+]
# Parse header and first data line to detect actual structure
try:
header_line = lines[0]
data_line = lines[1] if len(lines) > 1 else ""
@@ -300,27 +333,46 @@ class BenchmarkDataManager(BaseComponent):
has_header = sniffer.has_header(content[:1024])
if has_header:
-headers = list(csv.reader([header_line], dialect))[0]
-first_data_row = list(csv.reader([data_line], dialect))[0] if data_line else []
+headers = list(csv.reader([header_line], dialect))[
+0
+]
+first_data_row = (
+list(csv.reader([data_line], dialect))[0]
+if data_line
+else []
+)
else:
-headers = list(csv.reader([header_line], dialect))[0]
+headers = list(csv.reader([header_line], dialect))[
+0
+]
first_data_row = headers # first line is data
-headers = [f'col_{i}' for i in range(len(headers))]
+headers = [f"col_{i}" for i in range(len(headers))]
# Determine actual number of columns from data
-actual_columns = len(first_data_row) if first_data_row else len(headers)
+actual_columns = (
+len(first_data_row)
+if first_data_row
+else len(headers)
+)
# Create table with correct number of columns
create_sql = f"""
-CREATE TABLE IF NOT EXISTS {table_name} (
-{', '.join([f'"{h}" TEXT' for h in headers[:actual_columns]])}
-)
+CREATE TABLE IF NOT EXISTS {table_name} ({
+", ".join(
+[
+f'"{h}" TEXT'
+for h in headers[:actual_columns]
+]
+)
+})
"""
cursor.execute(create_sql)
# Prepare insert statement
insert_sql = f"""
-INSERT INTO {table_name} VALUES ({', '.join(['?'] * actual_columns)})
+INSERT INTO {table_name} VALUES ({
+", ".join(["?"] * actual_columns)
+})
"""
# Process data
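To make the reformatted f-string templates above easier to read: for a file with three detected headers they expand to ordinary CREATE TABLE and INSERT statements, roughly as in this sketch (table and header names are made up):

headers = ["id", "question", "answer"]
actual_columns = len(headers)
table_name = "falcon_example"
columns_sql = ", ".join([f'"{h}" TEXT' for h in headers[:actual_columns]])
create_sql = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns_sql})"
insert_sql = f"INSERT INTO {table_name} VALUES ({', '.join(['?'] * actual_columns)})"
# CREATE TABLE IF NOT EXISTS falcon_example ("id" TEXT, "question" TEXT, "answer" TEXT)
# INSERT INTO falcon_example VALUES (?, ?, ?)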
@@ -335,9 +387,6 @@ class BenchmarkDataManager(BaseComponent):
# Ensure row has correct number of columns
if len(row) != actual_columns:
-logger.warning(
-f"Adjusting row with {len(row)} columns to match {actual_columns} columns"
-)
if len(row) < actual_columns:
row += [None] * (actual_columns - len(row))
else:
@@ -351,16 +400,12 @@ class BenchmarkDataManager(BaseComponent):
if batch:
cursor.executemany(insert_sql, batch)
-results['successful'] += 1
-results['tables_created'].append(table_name)
-logger.info(
-f"Imported: {file_info['rel_path']}/{file_info['file_name']} -> {table_name}"
-)
+results["successful"] += 1
+results["tables_created"].append(table_name)
break
except csv.Error as e:
# Fallback for malformed CSV files
logger.warning(f"CSV parsing error ({encoding}): {str(e)}. Trying simple split")
self._import_with_simple_split(
cursor, table_name, content, results, file_info
)
@@ -374,8 +419,8 @@ class BenchmarkDataManager(BaseComponent):
else:
# All encodings failed - try binary mode as last resort
try:
-with open(file_info['full_path'], 'rb') as f:
-content = f.read().decode('ascii', errors='ignore')
+with open(file_info["full_path"], "rb") as f:
+content = f.read().decode("ascii", errors="ignore")
if content.strip():
self._import_with_simple_split(
cursor, table_name, content, results, file_info
@@ -383,28 +428,32 @@ class BenchmarkDataManager(BaseComponent):
else:
raise ValueError("File is empty or unreadable")
except Exception as e:
-results['failed'] += 1
-logger.error(f"Failed to process {file_info['file_name']}: {str(e)}")
+results["failed"] += 1
+logger.error(
+f"Failed to process {file_info['file_name']}: {str(e)}"
+)
except Exception as e:
-results['failed'] += 1
+results["failed"] += 1
logger.error(f"Failed to process {file_info['full_path']}: {str(e)}")
self._db_conn.commit()
return results
-def _import_with_simple_split(self, cursor, table_name, content, results, file_info):
+def _import_with_simple_split(
+self, cursor, table_name, content, results, file_info
+):
"""Fallback method for malformed CSV files"""
# Normalize line endings
-content = content.replace('\r\n', '\n').replace('\r', '\n')
-lines = [line for line in content.split('\n') if line.strip()]
+content = content.replace("\r\n", "\n").replace("\r", "\n")
+lines = [line for line in content.split("\n") if line.strip()]
if not lines:
raise ValueError("No data found after cleaning")
# Determine delimiter
first_line = lines[0]
-delimiter = ',' if ',' in first_line else '\t' if '\t' in first_line else ';'
+delimiter = "," if "," in first_line else "\t" if "\t" in first_line else ";"
# Process header
headers = first_line.split(delimiter)
@@ -413,14 +462,14 @@ class BenchmarkDataManager(BaseComponent):
# Create table
create_sql = f"""
CREATE TABLE IF NOT EXISTS {table_name} (
-{', '.join([f'col_{i} TEXT' for i in range(actual_columns)])}
+{", ".join([f"col_{i} TEXT" for i in range(actual_columns)])}
)
"""
cursor.execute(create_sql)
# Prepare insert
insert_sql = f"""
-INSERT INTO {table_name} VALUES ({', '.join(['?'] * actual_columns)})
+INSERT INTO {table_name} VALUES ({", ".join(["?"] * actual_columns)})
"""
# Process data
@@ -428,9 +477,6 @@ class BenchmarkDataManager(BaseComponent):
for line in lines[1:]: # skip header
row = line.split(delimiter)
if len(row) != actual_columns:
-logger.warning(
-f"Adjusting row with {len(row)} columns to match {actual_columns} columns"
-)
if len(row) < actual_columns:
row += [None] * (actual_columns - len(row))
else:
@@ -444,11 +490,8 @@ class BenchmarkDataManager(BaseComponent):
if batch:
cursor.executemany(insert_sql, batch)
-results['successful'] += 1
-results['tables_created'].append(table_name)
-logger.info(
-f"Imported (simple split): {file_info['rel_path']}/{file_info['file_name']} -> {table_name}"
-)
+results["successful"] += 1
+results["tables_created"].append(table_name)
def _cleanup_temp_dir(self):
"""Clean up temporary directory"""
@@ -463,7 +506,9 @@ class BenchmarkDataManager(BaseComponent):
_SYSTEM_APP: Optional[SystemApp] = None
-def initialize_benchmark_data(system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None):
+def initialize_benchmark_data(
+system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None
+):
"""Initialize benchmark data manager component"""
global _SYSTEM_APP
_SYSTEM_APP = system_app
@@ -472,10 +517,12 @@ def initialize_benchmark_data(system_app: SystemApp, config: Optional[BenchmarkD
return manager
-def get_benchmark_manager(system_app: Optional[SystemApp] = None) -> BenchmarkDataManager:
+def get_benchmark_manager(
+system_app: Optional[SystemApp] = None,
+) -> BenchmarkDataManager:
"""Get the benchmark data manager instance"""
if not _SYSTEM_APP:
if not system_app:
system_app = SystemApp()
initialize_benchmark_data(system_app)
return BenchmarkDataManager.get_instance(system_app)