opt: code format

Author: yaoyifan-yyf
Date: 2025-09-25 13:37:57 +08:00
parent 9e50bff12c
commit bc14084826
2 changed files with 143 additions and 99 deletions

View File

@@ -36,7 +36,9 @@ from dbgpt_app.base import (
 from dbgpt_app.component_configs import initialize_components
 from dbgpt_app.config import ApplicationConfig, ServiceWebParameters, SystemParameters
 from dbgpt_serve.core import add_exception_handler
-from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import BenchmarkDataManager, get_benchmark_manager
+from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
+    get_benchmark_manager,
+)

 logger = logging.getLogger(__name__)
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -153,8 +155,6 @@ def initialize_app(param: ApplicationConfig, args: List[str] = None):
     else:
         loop.run_until_complete(load_benchmark_data())

-
-
     binding_port = web_config.port
     binding_host = web_config.host
     if not web_config.light:
@@ -341,8 +341,7 @@ async def load_benchmark_data():
         async with manager:
             logger.info("Fetching data from GitHub repository...")
             result = await manager.load_from_github(
-                repo_url="https://github.com/inclusionAI/Falcon",
-                data_dir="data/source"
+                repo_url="https://github.com/inclusionAI/Falcon", data_dir="data/source"
             )

             # Log detailed results
@@ -351,22 +350,20 @@ async def load_benchmark_data():
             logger.info(f"Successfully imported: {result['successful']}")
             logger.info(f"Failed imports: {result['failed']}")

-            if result['failed'] > 0:
+            if result["failed"] > 0:
                 logger.warning(f"Encountered {result['failed']} failures during import")

             # Verify the loaded data
             table_info = await manager.get_table_info()
             logger.info(f"Loaded {len(table_info)} tables into database")

-            return {
-                'import_result': result,
-                'table_info': table_info
-            }
+            return {"import_result": result, "table_info": table_info}

     except Exception as e:
         logger.error("Failed to load benchmark data", exc_info=True)
         raise RuntimeError(f"Benchmark data loading failed: {str(e)}") from e

+
 if __name__ == "__main__":
     # Parse command line arguments
     _args = parse_args()
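For reference, the collapsed return statement keeps both pieces of information. A minimal consumer sketch (the driver function below is hypothetical; only load_benchmark_data and the dict keys come from this file):

    import asyncio

    async def show_benchmark_summary():
        # load_benchmark_data() returns {"import_result": ..., "table_info": ...}
        data = await load_benchmark_data()
        result = data["import_result"]
        print(f"{result['successful']} succeeded, {result['failed']} failed")
        print(f"{len(data['table_info'])} tables available")

    asyncio.run(show_benchmark_summary())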

View File

@@ -1,20 +1,21 @@
-import os
-import csv
-import sqlite3
-import aiohttp
-import asyncio
-import zipfile
-import tempfile
-import shutil
-import time
-import hashlib
-from pathlib import Path
-from typing import List, Dict, Optional, Type
-import logging
-import json
+import asyncio
+import csv
+import hashlib
+import json
+import logging
+import os
+import shutil
+import sqlite3
+import tempfile
+import time
+import zipfile
+from pathlib import Path
+from typing import Dict, List, Optional

-from dbgpt.component import BaseComponent, ComponentType, SystemApp
+import aiohttp
+
 from dbgpt._private.pydantic import BaseModel, ConfigDict
+from dbgpt.component import BaseComponent, ComponentType, SystemApp


 logger = logging.getLogger(__name__)
@@ -35,7 +36,9 @@ class BenchmarkDataManager(BaseComponent):
     name = ComponentType.BENCHMARK_DATA_MANAGER

-    def __init__(self, system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None):
+    def __init__(
+        self, system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None
+    ):
         super().__init__(system_app)
         self._config = config or BenchmarkDataConfig()
         self._http_session = None
@@ -83,7 +86,9 @@ class BenchmarkDataManager(BaseComponent):
             columns = [col[0] for col in cursor.description]
             return [dict(zip(columns, row)) for row in cursor.fetchall()]

-    async def load_from_github(self, repo_url: str, data_dir: str = "data/source") -> Dict:
+    async def load_from_github(
+        self, repo_url: str, data_dir: str = "data/source"
+    ) -> Dict:
         """Main method to load data from GitHub repository"""
         try:
             # 1. Download or use cached repository
@@ -97,7 +102,6 @@ class BenchmarkDataManager(BaseComponent):
             # 3. Import to SQLite
             result = await self._import_to_database(csv_files)
-            logger.info(f"Import completed: {result['successful']} succeeded, {result['failed']} failed")

             return result

         except Exception as e:
@@ -123,8 +127,8 @@ class BenchmarkDataManager(BaseComponent):
                 columns = cursor.fetchall()

                 result[table_name] = {
-                    'row_count': row_count,
-                    'columns': [{'name': col[1], 'type': col[2]} for col in columns]
+                    "row_count": row_count,
+                    "columns": [{"name": col[1], "type": col[2]} for col in columns],
                 }

             return result
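With the keys now double-quoted, the shape of the get_table_info() result is unchanged. A sketch of how a caller might read it (the helper below is illustrative, not part of this commit):

    async def describe_tables(manager):
        table_info = await manager.get_table_info()
        for table, meta in table_info.items():
            cols = [c["name"] for c in meta["columns"]]
            print(f"{table}: {meta['row_count']} rows, columns={cols}")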
@@ -144,15 +148,19 @@ class BenchmarkDataManager(BaseComponent):
     def _load_mappings(self) -> Dict[str, str]:
         """Load table name mappings from config file"""
-        if not self._config.table_mapping_file or not os.path.exists(self._config.table_mapping_file):
-            logger.warning(f"Table mapping file not found: {self._config.table_mapping_file}")
+        if not self._config.table_mapping_file or not os.path.exists(
+            self._config.table_mapping_file
+        ):
+            logger.warning(
+                f"Table mapping file not found: {self._config.table_mapping_file}"
+            )
             return {}

         try:
-            with open(self._config.table_mapping_file, 'r', encoding='utf-8') as f:
+            with open(self._config.table_mapping_file, "r", encoding="utf-8") as f:
                 mapping = json.load(f)
                 return {
-                    key: value.split('.')[-1] if '.' in value else value
+                    key: value.split(".")[-1] if "." in value else value
                     for key, value in mapping.items()
                 }
         except Exception as e:
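The mapping logic above strips any schema prefix from the configured value. A small illustration (the file contents are hypothetical, only the split(".")[-1] rule comes from this code):

    import json

    mapping = json.loads('{"sales_raw": "warehouse.sales", "users": "users"}')
    resolved = {k: v.split(".")[-1] if "." in v else v for k, v in mapping.items()}
    assert resolved == {"sales_raw": "sales", "users": "users"}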
@@ -164,11 +172,28 @@ class BenchmarkDataManager(BaseComponent):
         mapped_name = self._table_mappings.get(name.lower(), name)

         # Clean special characters
-        invalid_chars = ['-', ' ', '.', ',', ';', ':', '!', '?', "'", '"', '(', ')', '[', ']', '{', '}']
+        invalid_chars = [
+            "-",
+            " ",
+            ".",
+            ",",
+            ";",
+            ":",
+            "!",
+            "?",
+            "'",
+            '"',
+            "(",
+            ")",
+            "[",
+            "]",
+            "{",
+            "}",
+        ]
         for char in invalid_chars:
-            mapped_name = mapped_name.replace(char, '_')
+            mapped_name = mapped_name.replace(char, "_")

-        while '__' in mapped_name:
-            mapped_name = mapped_name.replace('__', '_')
+        while "__" in mapped_name:
+            mapped_name = mapped_name.replace("__", "_")

         return mapped_name.lower()
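The reformatted sanitizer is behavior-preserving. The same cleanup could also be written with re.sub, shown here only as an equivalent sketch (the character class mirrors invalid_chars, and "_+" collapses runs the way the while loop does):

    import re

    def sanitize(name: str) -> str:
        cleaned = re.sub(r"""[- .,;:!?'"()\[\]{}]""", "_", name)
        cleaned = re.sub(r"_+", "_", cleaned)
        return cleaned.lower()

    assert sanitize("My Table (v2).csv-data") == "my_table_v2_csv_data"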
@@ -183,15 +208,17 @@ class BenchmarkDataManager(BaseComponent):
         # Download fresh copy
         self.temp_dir = tempfile.mkdtemp()
-        zip_url = repo_url.replace('github.com', 'api.github.com/repos') + '/zipball/main'
+        zip_url = (
+            repo_url.replace("github.com", "api.github.com/repos") + "/zipball/main"
+        )
         logger.info(f"Downloading from GitHub repo: {zip_url}")

         try:
             async with self._http_session.get(zip_url) as response:
                 response.raise_for_status()

-                zip_path = os.path.join(self.temp_dir, 'repo.zip')
-                with open(zip_path, 'wb') as f:
+                zip_path = os.path.join(self.temp_dir, "repo.zip")
+                with open(zip_path, "wb") as f:
                     while True:
                         chunk = await response.content.read(1024)
                         if not chunk:
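The parenthesized zip_url expression is unchanged in effect; for the repository used here it resolves as follows (a quick check, not part of the commit):

    repo_url = "https://github.com/inclusionAI/Falcon"
    zip_url = repo_url.replace("github.com", "api.github.com/repos") + "/zipball/main"
    assert zip_url == "https://api.github.com/repos/inclusionAI/Falcon/zipball/main"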
@@ -210,7 +237,7 @@ class BenchmarkDataManager(BaseComponent):
     def _get_cache_path(self, repo_url: str) -> str:
         """Get path to cached zip file"""
-        cache_key = hashlib.md5(repo_url.encode('utf-8')).hexdigest()
+        cache_key = hashlib.md5(repo_url.encode("utf-8")).hexdigest()
         return os.path.join(self._config.cache_dir, f"{cache_key}.zip")

     def _is_cache_valid(self, cache_path: str) -> bool:
@@ -227,11 +254,14 @@ class BenchmarkDataManager(BaseComponent):
     def _extract_zip(self, zip_path: str) -> str:
         """Extract zip to temp directory"""
-        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+        with zipfile.ZipFile(zip_path, "r") as zip_ref:
             zip_ref.extractall(self.temp_dir)

-            extracted_dirs = [d for d in os.listdir(self.temp_dir)
-                              if os.path.isdir(os.path.join(self.temp_dir, d))]
+            extracted_dirs = [
+                d
+                for d in os.listdir(self.temp_dir)
+                if os.path.isdir(os.path.join(self.temp_dir, d))
+            ]
             if not extracted_dirs:
                 raise ValueError("No valid directory found after extraction")
             return os.path.join(self.temp_dir, extracted_dirs[0])
@@ -245,13 +275,15 @@ class BenchmarkDataManager(BaseComponent):
         csv_files = []
         for root, _, files in os.walk(full_search_dir):
             for file in files:
-                if file.lower().endswith('.csv'):
+                if file.lower().endswith(".csv"):
                     rel_path = os.path.relpath(root, start=base_dir)
-                    csv_files.append({
-                        'full_path': os.path.join(root, file),
-                        'rel_path': rel_path,
-                        'file_name': file
-                    })
+                    csv_files.append(
+                        {
+                            "full_path": os.path.join(root, file),
+                            "rel_path": rel_path,
+                            "file_name": file,
+                        }
+                    )
         return csv_files

     async def _import_to_database(self, csv_files: List[Dict]) -> Dict:
@@ -259,24 +291,24 @@ class BenchmarkDataManager(BaseComponent):
         conn = await self.get_connection()
         cursor = conn.cursor()

         results = {
-            'total_files': len(csv_files),
-            'successful': 0,
-            'failed': 0,
-            'tables_created': []
+            "total_files": len(csv_files),
+            "successful": 0,
+            "failed": 0,
+            "tables_created": [],
         }

         for file_info in csv_files:
             try:
-                path_parts = [p for p in file_info['rel_path'].split(os.sep) if p]
-                table_name = '_'.join(path_parts + [Path(file_info['file_name']).stem])
+                path_parts = [p for p in file_info["rel_path"].split(os.sep) if p]
+                table_name = "_".join(path_parts + [Path(file_info["file_name"]).stem])
                 table_name = self._sanitize_table_name(table_name)

                 # Try multiple encodings
-                encodings = ['utf-8-sig', 'utf-8', 'latin-1', 'iso-8859-1', 'cp1252']
+                encodings = ["utf-8-sig", "utf-8", "latin-1", "iso-8859-1", "cp1252"]

                 for encoding in encodings:
                     try:
-                        with open(file_info['full_path'], 'r', encoding=encoding) as f:
+                        with open(file_info["full_path"], "r", encoding=encoding) as f:
                             content = f.read()

                             # Handle empty files
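One caveat worth noting about the encodings list above (an observation, not something this commit changes): "latin-1" can decode any byte sequence, so the "iso-8859-1" and "cp1252" entries after it are effectively unreachable:

    # Every byte is valid latin-1, so this never raises UnicodeDecodeError:
    assert b"\xff\xfe\x9d".decode("latin-1") == "ÿþ\x9d"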
@@ -284,12 +316,13 @@ class BenchmarkDataManager(BaseComponent):
                                 raise ValueError("File is empty")

                             # Replace problematic line breaks if needed
-                            content = content.replace('\r\n', '\n').replace('\r', '\n')
+                            content = content.replace("\r\n", "\n").replace("\r", "\n")

                             # Split into lines
-                            lines = [line for line in content.split('\n') if line.strip()]
+                            lines = [
+                                line for line in content.split("\n") if line.strip()
+                            ]

-                            # Parse header and first data line to detect actual structure
                             try:
                                 header_line = lines[0]
                                 data_line = lines[1] if len(lines) > 1 else ""
@@ -300,27 +333,46 @@ class BenchmarkDataManager(BaseComponent):
                                 has_header = sniffer.has_header(content[:1024])

                                 if has_header:
-                                    headers = list(csv.reader([header_line], dialect))[0]
-                                    first_data_row = list(csv.reader([data_line], dialect))[0] if data_line else []
+                                    headers = list(csv.reader([header_line], dialect))[
+                                        0
+                                    ]
+                                    first_data_row = (
+                                        list(csv.reader([data_line], dialect))[0]
+                                        if data_line
+                                        else []
+                                    )
                                 else:
-                                    headers = list(csv.reader([header_line], dialect))[0]
+                                    headers = list(csv.reader([header_line], dialect))[
+                                        0
+                                    ]
                                     first_data_row = headers  # first line is data
-                                    headers = [f'col_{i}' for i in range(len(headers))]
+                                    headers = [f"col_{i}" for i in range(len(headers))]

                                 # Determine actual number of columns from data
-                                actual_columns = len(first_data_row) if first_data_row else len(headers)
+                                actual_columns = (
+                                    len(first_data_row)
+                                    if first_data_row
+                                    else len(headers)
+                                )

                                 # Create table with correct number of columns
                                 create_sql = f"""
-                                CREATE TABLE IF NOT EXISTS {table_name} (
-                                    {', '.join([f'"{h}" TEXT' for h in headers[:actual_columns]])}
-                                )
+                                CREATE TABLE IF NOT EXISTS {table_name} ({
+                                    ", ".join(
+                                        [
+                                            f'"{h}" TEXT'
+                                            for h in headers[:actual_columns]
+                                        ]
+                                    )
+                                })
                                 """
                                 cursor.execute(create_sql)

                                 # Prepare insert statement
                                 insert_sql = f"""
-                                INSERT INTO {table_name} VALUES ({', '.join(['?'] * actual_columns)})
+                                INSERT INTO {table_name} VALUES ({
+                                    ", ".join(["?"] * actual_columns)
+                                })
                                 """

                                 # Process data
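The rewrapped create_sql relies on the fact that a replacement field in a triple-quoted f-string may span several lines and use double quotes without terminating the literal. A minimal standalone illustration (names are placeholders):

    headers = ["id", "name"]
    table_name = "demo"
    create_sql = f"""
    CREATE TABLE IF NOT EXISTS {table_name} ({
        ", ".join([f'"{h}" TEXT' for h in headers])
    })
    """
    # -> CREATE TABLE IF NOT EXISTS demo ("id" TEXT, "name" TEXT)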
@@ -335,9 +387,6 @@ class BenchmarkDataManager(BaseComponent):

                                     # Ensure row has correct number of columns
                                     if len(row) != actual_columns:
-                                        logger.warning(
-                                            f"Adjusting row with {len(row)} columns to match {actual_columns} columns"
-                                        )
                                         if len(row) < actual_columns:
                                             row += [None] * (actual_columns - len(row))
                                         else:
@@ -351,16 +400,12 @@ class BenchmarkDataManager(BaseComponent):
                             if batch:
                                 cursor.executemany(insert_sql, batch)

-                            results['successful'] += 1
-                            results['tables_created'].append(table_name)
-                            logger.info(
-                                f"Imported: {file_info['rel_path']}/{file_info['file_name']} -> {table_name}"
-                            )
+                            results["successful"] += 1
+                            results["tables_created"].append(table_name)
                             break

                         except csv.Error as e:
                             # Fallback for malformed CSV files
-                            logger.warning(f"CSV parsing error ({encoding}): {str(e)}. Trying simple split")
                             self._import_with_simple_split(
                                 cursor, table_name, content, results, file_info
                             )
@@ -374,8 +419,8 @@ class BenchmarkDataManager(BaseComponent):
                     else:
                         # All encodings failed - try binary mode as last resort
                         try:
-                            with open(file_info['full_path'], 'rb') as f:
-                                content = f.read().decode('ascii', errors='ignore')
+                            with open(file_info["full_path"], "rb") as f:
+                                content = f.read().decode("ascii", errors="ignore")
                             if content.strip():
                                 self._import_with_simple_split(
                                     cursor, table_name, content, results, file_info
@@ -383,28 +428,32 @@ class BenchmarkDataManager(BaseComponent):
                             else:
                                 raise ValueError("File is empty or unreadable")
                         except Exception as e:
-                            results['failed'] += 1
-                            logger.error(f"Failed to process {file_info['file_name']}: {str(e)}")
+                            results["failed"] += 1
+                            logger.error(
+                                f"Failed to process {file_info['file_name']}: {str(e)}"
+                            )

             except Exception as e:
-                results['failed'] += 1
+                results["failed"] += 1
                 logger.error(f"Failed to process {file_info['full_path']}: {str(e)}")

         self._db_conn.commit()
         return results

-    def _import_with_simple_split(self, cursor, table_name, content, results, file_info):
+    def _import_with_simple_split(
+        self, cursor, table_name, content, results, file_info
+    ):
         """Fallback method for malformed CSV files"""
         # Normalize line endings
-        content = content.replace('\r\n', '\n').replace('\r', '\n')
-        lines = [line for line in content.split('\n') if line.strip()]
+        content = content.replace("\r\n", "\n").replace("\r", "\n")
+        lines = [line for line in content.split("\n") if line.strip()]

         if not lines:
             raise ValueError("No data found after cleaning")

         # Determine delimiter
         first_line = lines[0]
-        delimiter = ',' if ',' in first_line else '\t' if '\t' in first_line else ';'
+        delimiter = "," if "," in first_line else "\t" if "\t" in first_line else ";"

         # Process header
         headers = first_line.split(delimiter)
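The single-line delimiter heuristic keeps its original precedence after requoting: comma first, then tab, then semicolon as the default. A tiny check (illustrative only):

    first_line = "a;b;c"
    delimiter = "," if "," in first_line else "\t" if "\t" in first_line else ";"
    assert delimiter == ";"  # no comma or tab present, so semicolon wins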
@@ -413,14 +462,14 @@ class BenchmarkDataManager(BaseComponent):
         # Create table
         create_sql = f"""
         CREATE TABLE IF NOT EXISTS {table_name} (
-            {', '.join([f'col_{i} TEXT' for i in range(actual_columns)])}
+            {", ".join([f"col_{i} TEXT" for i in range(actual_columns)])}
         )
         """
         cursor.execute(create_sql)

         # Prepare insert
         insert_sql = f"""
-        INSERT INTO {table_name} VALUES ({', '.join(['?'] * actual_columns)})
+        INSERT INTO {table_name} VALUES ({", ".join(["?"] * actual_columns)})
         """

         # Process data
@@ -428,9 +477,6 @@ class BenchmarkDataManager(BaseComponent):
         for line in lines[1:]:  # skip header
             row = line.split(delimiter)
             if len(row) != actual_columns:
-                logger.warning(
-                    f"Adjusting row with {len(row)} columns to match {actual_columns} columns"
-                )
                 if len(row) < actual_columns:
                     row += [None] * (actual_columns - len(row))
                 else:
@@ -444,11 +490,8 @@ class BenchmarkDataManager(BaseComponent):
         if batch:
             cursor.executemany(insert_sql, batch)

-        results['successful'] += 1
-        results['tables_created'].append(table_name)
-        logger.info(
-            f"Imported (simple split): {file_info['rel_path']}/{file_info['file_name']} -> {table_name}"
-        )
+        results["successful"] += 1
+        results["tables_created"].append(table_name)

     def _cleanup_temp_dir(self):
         """Clean up temporary directory"""
@@ -463,7 +506,9 @@
 _SYSTEM_APP: Optional[SystemApp] = None


-def initialize_benchmark_data(system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None):
+def initialize_benchmark_data(
+    system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None
+):
     """Initialize benchmark data manager component"""
     global _SYSTEM_APP
     _SYSTEM_APP = system_app
@@ -472,7 +517,9 @@ def initialize_benchmark_data(system_app: SystemApp, config: Optional[BenchmarkD
     return manager


-def get_benchmark_manager(system_app: Optional[SystemApp] = None) -> BenchmarkDataManager:
+def get_benchmark_manager(
+    system_app: Optional[SystemApp] = None,
+) -> BenchmarkDataManager:
     """Get the benchmark data manager instance"""
     if not _SYSTEM_APP:
         if not system_app:
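Taken together with the dbgpt_server change above, the expected call pattern for the reformatted accessor looks like this (a sketch assembled from names appearing in this diff, inside an async context; not a new API):

    manager = get_benchmark_manager(system_app)
    async with manager:
        result = await manager.load_from_github(
            repo_url="https://github.com/inclusionAI/Falcon", data_dir="data/source"
        )
        print(result["successful"], result["failed"])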