Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-26 20:13:40 +00:00)
opt: code format
@@ -36,7 +36,9 @@ from dbgpt_app.base import (
 from dbgpt_app.component_configs import initialize_components
 from dbgpt_app.config import ApplicationConfig, ServiceWebParameters, SystemParameters
 from dbgpt_serve.core import add_exception_handler
-from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import BenchmarkDataManager, get_benchmark_manager
+from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
+    get_benchmark_manager,
+)

 logger = logging.getLogger(__name__)
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -153,8 +155,6 @@ def initialize_app(param: ApplicationConfig, args: List[str] = None):
     else:
         loop.run_until_complete(load_benchmark_data())

-
-
     binding_port = web_config.port
     binding_host = web_config.host
     if not web_config.light:
@@ -341,8 +341,7 @@ async def load_benchmark_data():
         async with manager:
             logger.info("Fetching data from GitHub repository...")
             result = await manager.load_from_github(
-                repo_url="https://github.com/inclusionAI/Falcon",
-                data_dir="data/source"
+                repo_url="https://github.com/inclusionAI/Falcon", data_dir="data/source"
             )

             # Log detailed results
@@ -351,22 +350,20 @@ async def load_benchmark_data():
             logger.info(f"Successfully imported: {result['successful']}")
             logger.info(f"Failed imports: {result['failed']}")

-            if result['failed'] > 0:
+            if result["failed"] > 0:
                 logger.warning(f"Encountered {result['failed']} failures during import")

             # Verify the loaded data
             table_info = await manager.get_table_info()
             logger.info(f"Loaded {len(table_info)} tables into database")

-            return {
-                'import_result': result,
-                'table_info': table_info
-            }
+            return {"import_result": result, "table_info": table_info}

     except Exception as e:
         logger.error("Failed to load benchmark data", exc_info=True)
         raise RuntimeError(f"Benchmark data loading failed: {str(e)}") from e

+
 if __name__ == "__main__":
     # Parse command line arguments
     _args = parse_args()
@@ -1,20 +1,21 @@
-import os
-import csv
-import sqlite3
-import aiohttp
 import asyncio
-import zipfile
-import tempfile
-import shutil
-import time
 import csv
 import hashlib
-from pathlib import Path
-from typing import List, Dict, Optional, Type
-import logging
 import json
+import logging
+import os
+import shutil
+import sqlite3
+import tempfile
+import time
+import zipfile
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import aiohttp

-from dbgpt.component import BaseComponent, ComponentType, SystemApp
 from dbgpt._private.pydantic import BaseModel, ConfigDict
+from dbgpt.component import BaseComponent, ComponentType, SystemApp

 logger = logging.getLogger(__name__)
+
@@ -35,7 +36,9 @@ class BenchmarkDataManager(BaseComponent):

     name = ComponentType.BENCHMARK_DATA_MANAGER

-    def __init__(self, system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None):
+    def __init__(
+        self, system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None
+    ):
         super().__init__(system_app)
         self._config = config or BenchmarkDataConfig()
         self._http_session = None
@@ -83,7 +86,9 @@ class BenchmarkDataManager(BaseComponent):
         columns = [col[0] for col in cursor.description]
         return [dict(zip(columns, row)) for row in cursor.fetchall()]

-    async def load_from_github(self, repo_url: str, data_dir: str = "data/source") -> Dict:
+    async def load_from_github(
+        self, repo_url: str, data_dir: str = "data/source"
+    ) -> Dict:
         """Main method to load data from GitHub repository"""
         try:
             # 1. Download or use cached repository
@@ -97,7 +102,6 @@ class BenchmarkDataManager(BaseComponent):

             # 3. Import to SQLite
             result = await self._import_to_database(csv_files)
-            logger.info(f"Import completed: {result['successful']} succeeded, {result['failed']} failed")
             return result

         except Exception as e:
@@ -123,8 +127,8 @@ class BenchmarkDataManager(BaseComponent):
             columns = cursor.fetchall()

             result[table_name] = {
-                'row_count': row_count,
-                'columns': [{'name': col[1], 'type': col[2]} for col in columns]
+                "row_count": row_count,
+                "columns": [{"name": col[1], "type": col[2]} for col in columns],
             }
         return result

@@ -144,15 +148,19 @@ class BenchmarkDataManager(BaseComponent):

     def _load_mappings(self) -> Dict[str, str]:
         """Load table name mappings from config file"""
-        if not self._config.table_mapping_file or not os.path.exists(self._config.table_mapping_file):
-            logger.warning(f"Table mapping file not found: {self._config.table_mapping_file}")
+        if not self._config.table_mapping_file or not os.path.exists(
+            self._config.table_mapping_file
+        ):
+            logger.warning(
+                f"Table mapping file not found: {self._config.table_mapping_file}"
+            )
             return {}

         try:
-            with open(self._config.table_mapping_file, 'r', encoding='utf-8') as f:
+            with open(self._config.table_mapping_file, "r", encoding="utf-8") as f:
                 mapping = json.load(f)
                 return {
-                    key: value.split('.')[-1] if '.' in value else value
+                    key: value.split(".")[-1] if "." in value else value
                     for key, value in mapping.items()
                 }
         except Exception as e:
@@ -164,11 +172,28 @@ class BenchmarkDataManager(BaseComponent):
         mapped_name = self._table_mappings.get(name.lower(), name)

         # Clean special characters
-        invalid_chars = ['-', ' ', '.', ',', ';', ':', '!', '?', "'", '"', '(', ')', '[', ']', '{', '}']
+        invalid_chars = [
+            "-",
+            " ",
+            ".",
+            ",",
+            ";",
+            ":",
+            "!",
+            "?",
+            "'",
+            '"',
+            "(",
+            ")",
+            "[",
+            "]",
+            "{",
+            "}",
+        ]
         for char in invalid_chars:
-            mapped_name = mapped_name.replace(char, '_')
-        while '__' in mapped_name:
-            mapped_name = mapped_name.replace('__', '_')
+            mapped_name = mapped_name.replace(char, "_")
+        while "__" in mapped_name:
+            mapped_name = mapped_name.replace("__", "_")

         return mapped_name.lower()

@@ -183,15 +208,17 @@ class BenchmarkDataManager(BaseComponent):

         # Download fresh copy
         self.temp_dir = tempfile.mkdtemp()
-        zip_url = repo_url.replace('github.com', 'api.github.com/repos') + '/zipball/main'
+        zip_url = (
+            repo_url.replace("github.com", "api.github.com/repos") + "/zipball/main"
+        )
         logger.info(f"Downloading from GitHub repo: {zip_url}")

         try:
             async with self._http_session.get(zip_url) as response:
                 response.raise_for_status()
-                zip_path = os.path.join(self.temp_dir, 'repo.zip')
+                zip_path = os.path.join(self.temp_dir, "repo.zip")

-                with open(zip_path, 'wb') as f:
+                with open(zip_path, "wb") as f:
                     while True:
                         chunk = await response.content.read(1024)
                         if not chunk:
@@ -210,7 +237,7 @@ class BenchmarkDataManager(BaseComponent):

     def _get_cache_path(self, repo_url: str) -> str:
         """Get path to cached zip file"""
-        cache_key = hashlib.md5(repo_url.encode('utf-8')).hexdigest()
+        cache_key = hashlib.md5(repo_url.encode("utf-8")).hexdigest()
         return os.path.join(self._config.cache_dir, f"{cache_key}.zip")

     def _is_cache_valid(self, cache_path: str) -> bool:
@@ -227,11 +254,14 @@ class BenchmarkDataManager(BaseComponent):

     def _extract_zip(self, zip_path: str) -> str:
         """Extract zip to temp directory"""
-        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+        with zipfile.ZipFile(zip_path, "r") as zip_ref:
             zip_ref.extractall(self.temp_dir)

-        extracted_dirs = [d for d in os.listdir(self.temp_dir)
-                          if os.path.isdir(os.path.join(self.temp_dir, d))]
+        extracted_dirs = [
+            d
+            for d in os.listdir(self.temp_dir)
+            if os.path.isdir(os.path.join(self.temp_dir, d))
+        ]
         if not extracted_dirs:
             raise ValueError("No valid directory found after extraction")
         return os.path.join(self.temp_dir, extracted_dirs[0])
@@ -245,13 +275,15 @@ class BenchmarkDataManager(BaseComponent):
         csv_files = []
         for root, _, files in os.walk(full_search_dir):
             for file in files:
-                if file.lower().endswith('.csv'):
+                if file.lower().endswith(".csv"):
                     rel_path = os.path.relpath(root, start=base_dir)
-                    csv_files.append({
-                        'full_path': os.path.join(root, file),
-                        'rel_path': rel_path,
-                        'file_name': file
-                    })
+                    csv_files.append(
+                        {
+                            "full_path": os.path.join(root, file),
+                            "rel_path": rel_path,
+                            "file_name": file,
+                        }
+                    )
         return csv_files

     async def _import_to_database(self, csv_files: List[Dict]) -> Dict:
@@ -259,24 +291,24 @@ class BenchmarkDataManager(BaseComponent):
         conn = await self.get_connection()
         cursor = conn.cursor()
         results = {
-            'total_files': len(csv_files),
-            'successful': 0,
-            'failed': 0,
-            'tables_created': []
+            "total_files": len(csv_files),
+            "successful": 0,
+            "failed": 0,
+            "tables_created": [],
         }

         for file_info in csv_files:
             try:
-                path_parts = [p for p in file_info['rel_path'].split(os.sep) if p]
-                table_name = '_'.join(path_parts + [Path(file_info['file_name']).stem])
+                path_parts = [p for p in file_info["rel_path"].split(os.sep) if p]
+                table_name = "_".join(path_parts + [Path(file_info["file_name"]).stem])
                 table_name = self._sanitize_table_name(table_name)

                 # Try multiple encodings
-                encodings = ['utf-8-sig', 'utf-8', 'latin-1', 'iso-8859-1', 'cp1252']
+                encodings = ["utf-8-sig", "utf-8", "latin-1", "iso-8859-1", "cp1252"]

                 for encoding in encodings:
                     try:
-                        with open(file_info['full_path'], 'r', encoding=encoding) as f:
+                        with open(file_info["full_path"], "r", encoding=encoding) as f:
                             content = f.read()

                             # Handle empty files
@@ -284,12 +316,13 @@ class BenchmarkDataManager(BaseComponent):
                                 raise ValueError("File is empty")

                             # Replace problematic line breaks if needed
-                            content = content.replace('\r\n', '\n').replace('\r', '\n')
+                            content = content.replace("\r\n", "\n").replace("\r", "\n")

                             # Split into lines
-                            lines = [line for line in content.split('\n') if line.strip()]
+                            lines = [
+                                line for line in content.split("\n") if line.strip()
+                            ]

                             # Parse header and first data line to detect actual structure
                             try:
                                 header_line = lines[0]
                                 data_line = lines[1] if len(lines) > 1 else ""
@@ -300,27 +333,46 @@ class BenchmarkDataManager(BaseComponent):
                                 has_header = sniffer.has_header(content[:1024])

                                 if has_header:
-                                    headers = list(csv.reader([header_line], dialect))[0]
-                                    first_data_row = list(csv.reader([data_line], dialect))[0] if data_line else []
+                                    headers = list(csv.reader([header_line], dialect))[
+                                        0
+                                    ]
+                                    first_data_row = (
+                                        list(csv.reader([data_line], dialect))[0]
+                                        if data_line
+                                        else []
+                                    )
                                 else:
-                                    headers = list(csv.reader([header_line], dialect))[0]
+                                    headers = list(csv.reader([header_line], dialect))[
+                                        0
+                                    ]
                                     first_data_row = headers  # first line is data
-                                    headers = [f'col_{i}' for i in range(len(headers))]
+                                    headers = [f"col_{i}" for i in range(len(headers))]

                                 # Determine actual number of columns from data
-                                actual_columns = len(first_data_row) if first_data_row else len(headers)
+                                actual_columns = (
+                                    len(first_data_row)
+                                    if first_data_row
+                                    else len(headers)
+                                )

                                 # Create table with correct number of columns
                                 create_sql = f"""
-                                    CREATE TABLE IF NOT EXISTS {table_name} (
-                                        {', '.join([f'"{h}" TEXT' for h in headers[:actual_columns]])}
-                                    )
+                                    CREATE TABLE IF NOT EXISTS {table_name} ({
+                                        ", ".join(
+                                            [
+                                                f'"{h}" TEXT'
+                                                for h in headers[:actual_columns]
+                                            ]
+                                        )
+                                    })
                                 """
                                 cursor.execute(create_sql)

                                 # Prepare insert statement
                                 insert_sql = f"""
-                                    INSERT INTO {table_name} VALUES ({', '.join(['?'] * actual_columns)})
+                                    INSERT INTO {table_name} VALUES ({
+                                        ", ".join(["?"] * actual_columns)
+                                    })
                                 """

                                 # Process data
@@ -335,9 +387,6 @@ class BenchmarkDataManager(BaseComponent):

                                     # Ensure row has correct number of columns
                                     if len(row) != actual_columns:
-                                        logger.warning(
-                                            f"Adjusting row with {len(row)} columns to match {actual_columns} columns"
-                                        )
                                         if len(row) < actual_columns:
                                             row += [None] * (actual_columns - len(row))
                                         else:
@@ -351,16 +400,12 @@ class BenchmarkDataManager(BaseComponent):
                                 if batch:
                                     cursor.executemany(insert_sql, batch)

-                                results['successful'] += 1
-                                results['tables_created'].append(table_name)
-                                logger.info(
-                                    f"Imported: {file_info['rel_path']}/{file_info['file_name']} -> {table_name}"
-                                )
+                                results["successful"] += 1
+                                results["tables_created"].append(table_name)
                                 break

                             except csv.Error as e:
                                 # Fallback for malformed CSV files
-                                logger.warning(f"CSV parsing error ({encoding}): {str(e)}. Trying simple split")
                                 self._import_with_simple_split(
                                     cursor, table_name, content, results, file_info
                                 )
@@ -374,8 +419,8 @@ class BenchmarkDataManager(BaseComponent):
                 else:
                     # All encodings failed - try binary mode as last resort
                     try:
-                        with open(file_info['full_path'], 'rb') as f:
-                            content = f.read().decode('ascii', errors='ignore')
+                        with open(file_info["full_path"], "rb") as f:
+                            content = f.read().decode("ascii", errors="ignore")
                         if content.strip():
                             self._import_with_simple_split(
                                 cursor, table_name, content, results, file_info
@@ -383,28 +428,32 @@ class BenchmarkDataManager(BaseComponent):
                         else:
                             raise ValueError("File is empty or unreadable")
                     except Exception as e:
-                        results['failed'] += 1
-                        logger.error(f"Failed to process {file_info['file_name']}: {str(e)}")
+                        results["failed"] += 1
+                        logger.error(
+                            f"Failed to process {file_info['file_name']}: {str(e)}"
+                        )

             except Exception as e:
-                results['failed'] += 1
+                results["failed"] += 1
                 logger.error(f"Failed to process {file_info['full_path']}: {str(e)}")

         self._db_conn.commit()
         return results

-    def _import_with_simple_split(self, cursor, table_name, content, results, file_info):
+    def _import_with_simple_split(
+        self, cursor, table_name, content, results, file_info
+    ):
         """Fallback method for malformed CSV files"""
         # Normalize line endings
-        content = content.replace('\r\n', '\n').replace('\r', '\n')
-        lines = [line for line in content.split('\n') if line.strip()]
+        content = content.replace("\r\n", "\n").replace("\r", "\n")
+        lines = [line for line in content.split("\n") if line.strip()]

         if not lines:
             raise ValueError("No data found after cleaning")

         # Determine delimiter
         first_line = lines[0]
-        delimiter = ',' if ',' in first_line else '\t' if '\t' in first_line else ';'
+        delimiter = "," if "," in first_line else "\t" if "\t" in first_line else ";"

         # Process header
         headers = first_line.split(delimiter)
@@ -413,14 +462,14 @@ class BenchmarkDataManager(BaseComponent):
         # Create table
         create_sql = f"""
             CREATE TABLE IF NOT EXISTS {table_name} (
-                {', '.join([f'col_{i} TEXT' for i in range(actual_columns)])}
+                {", ".join([f"col_{i} TEXT" for i in range(actual_columns)])}
             )
         """
         cursor.execute(create_sql)

         # Prepare insert
         insert_sql = f"""
-            INSERT INTO {table_name} VALUES ({', '.join(['?'] * actual_columns)})
+            INSERT INTO {table_name} VALUES ({", ".join(["?"] * actual_columns)})
         """

         # Process data
@@ -428,9 +477,6 @@ class BenchmarkDataManager(BaseComponent):
         for line in lines[1:]:  # skip header
             row = line.split(delimiter)
             if len(row) != actual_columns:
-                logger.warning(
-                    f"Adjusting row with {len(row)} columns to match {actual_columns} columns"
-                )
                 if len(row) < actual_columns:
                     row += [None] * (actual_columns - len(row))
                 else:
@@ -444,11 +490,8 @@ class BenchmarkDataManager(BaseComponent):
         if batch:
             cursor.executemany(insert_sql, batch)

-        results['successful'] += 1
-        results['tables_created'].append(table_name)
-        logger.info(
-            f"Imported (simple split): {file_info['rel_path']}/{file_info['file_name']} -> {table_name}"
-        )
+        results["successful"] += 1
+        results["tables_created"].append(table_name)

     def _cleanup_temp_dir(self):
         """Clean up temporary directory"""
@@ -463,7 +506,9 @@ class BenchmarkDataManager(BaseComponent):
 _SYSTEM_APP: Optional[SystemApp] = None


-def initialize_benchmark_data(system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None):
+def initialize_benchmark_data(
+    system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None
+):
     """Initialize benchmark data manager component"""
     global _SYSTEM_APP
     _SYSTEM_APP = system_app
@@ -472,10 +517,12 @@ def initialize_benchmark_data(system_app: SystemApp, config: Optional[BenchmarkD
     return manager


-def get_benchmark_manager(system_app: Optional[SystemApp] = None) -> BenchmarkDataManager:
+def get_benchmark_manager(
+    system_app: Optional[SystemApp] = None,
+) -> BenchmarkDataManager:
     """Get the benchmark data manager instance"""
     if not _SYSTEM_APP:
         if not system_app:
             system_app = SystemApp()
         initialize_benchmark_data(system_app)
         return BenchmarkDataManager.get_instance(system_app)