opt: code format
@@ -36,7 +36,9 @@ from dbgpt_app.base import (
 from dbgpt_app.component_configs import initialize_components
 from dbgpt_app.config import ApplicationConfig, ServiceWebParameters, SystemParameters
 from dbgpt_serve.core import add_exception_handler
-from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import BenchmarkDataManager, get_benchmark_manager
+from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
+    get_benchmark_manager,
+)

 logger = logging.getLogger(__name__)
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -153,8 +155,6 @@ def initialize_app(param: ApplicationConfig, args: List[str] = None):
     else:
         loop.run_until_complete(load_benchmark_data())

-
-
     binding_port = web_config.port
     binding_host = web_config.host
     if not web_config.light:
@@ -341,8 +341,7 @@ async def load_benchmark_data():
         async with manager:
             logger.info("Fetching data from GitHub repository...")
             result = await manager.load_from_github(
-                repo_url="https://github.com/inclusionAI/Falcon",
-                data_dir="data/source"
+                repo_url="https://github.com/inclusionAI/Falcon", data_dir="data/source"
             )

             # Log detailed results
@@ -351,22 +350,20 @@ async def load_benchmark_data():
             logger.info(f"Successfully imported: {result['successful']}")
             logger.info(f"Failed imports: {result['failed']}")

-            if result['failed'] > 0:
+            if result["failed"] > 0:
                 logger.warning(f"Encountered {result['failed']} failures during import")

             # Verify the loaded data
             table_info = await manager.get_table_info()
             logger.info(f"Loaded {len(table_info)} tables into database")

-            return {
-                'import_result': result,
-                'table_info': table_info
-            }
+            return {"import_result": result, "table_info": table_info}

     except Exception as e:
         logger.error("Failed to load benchmark data", exc_info=True)
         raise RuntimeError(f"Benchmark data loading failed: {str(e)}") from e

+
 if __name__ == "__main__":
     # Parse command line arguments
     _args = parse_args()
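The hunks below are in the benchmark data manager module imported above. For reference, load_benchmark_data keeps returning the dict shown in this hunk, {"import_result": ..., "table_info": ...}; a minimal sketch of consuming that shape, with summarize_benchmark_result as a hypothetical helper rather than repository code:

    def summarize_benchmark_result(payload: dict) -> str:
        """Summarize the {"import_result": ..., "table_info": ...} dict shown above."""
        imported = payload["import_result"]
        tables = payload["table_info"]
        return (
            f"{imported['successful']} succeeded, {imported['failed']} failed, "
            f"{len(tables)} tables loaded"
        )


    # Keys mirror what _import_to_database and get_table_info produce later in this diff.
    example = {
        "import_result": {"total_files": 1, "successful": 1, "failed": 0, "tables_created": ["t"]},
        "table_info": {"t": {"row_count": 10, "columns": [{"name": "col_0", "type": "TEXT"}]}},
    }
    print(summarize_benchmark_result(example))  # -> 1 succeeded, 0 failed, 1 tables loaded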
@@ -1,20 +1,21 @@
-import os
-import csv
-import sqlite3
-import aiohttp
 import asyncio
-import zipfile
-import tempfile
-import shutil
-import time
+import csv
 import hashlib
-from pathlib import Path
-from typing import List, Dict, Optional, Type
-import logging
 import json
+import logging
+import os
+import shutil
+import sqlite3
+import tempfile
+import time
+import zipfile
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import aiohttp

-from dbgpt.component import BaseComponent, ComponentType, SystemApp
 from dbgpt._private.pydantic import BaseModel, ConfigDict
+from dbgpt.component import BaseComponent, ComponentType, SystemApp

 logger = logging.getLogger(__name__)

@@ -35,7 +36,9 @@ class BenchmarkDataManager(BaseComponent):

     name = ComponentType.BENCHMARK_DATA_MANAGER

-    def __init__(self, system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None):
+    def __init__(
+        self, system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None
+    ):
         super().__init__(system_app)
         self._config = config or BenchmarkDataConfig()
         self._http_session = None
@@ -83,7 +86,9 @@ class BenchmarkDataManager(BaseComponent):
             columns = [col[0] for col in cursor.description]
             return [dict(zip(columns, row)) for row in cursor.fetchall()]

-    async def load_from_github(self, repo_url: str, data_dir: str = "data/source") -> Dict:
+    async def load_from_github(
+        self, repo_url: str, data_dir: str = "data/source"
+    ) -> Dict:
         """Main method to load data from GitHub repository"""
         try:
             # 1. Download or use cached repository
@@ -97,7 +102,6 @@ class BenchmarkDataManager(BaseComponent):

             # 3. Import to SQLite
             result = await self._import_to_database(csv_files)
-            logger.info(f"Import completed: {result['successful']} succeeded, {result['failed']} failed")
             return result

         except Exception as e:
@@ -123,8 +127,8 @@ class BenchmarkDataManager(BaseComponent):
                 columns = cursor.fetchall()

                 result[table_name] = {
-                    'row_count': row_count,
-                    'columns': [{'name': col[1], 'type': col[2]} for col in columns]
+                    "row_count": row_count,
+                    "columns": [{"name": col[1], "type": col[2]} for col in columns],
                 }
             return result

@@ -144,15 +148,19 @@ class BenchmarkDataManager(BaseComponent):

     def _load_mappings(self) -> Dict[str, str]:
         """Load table name mappings from config file"""
-        if not self._config.table_mapping_file or not os.path.exists(self._config.table_mapping_file):
-            logger.warning(f"Table mapping file not found: {self._config.table_mapping_file}")
+        if not self._config.table_mapping_file or not os.path.exists(
+            self._config.table_mapping_file
+        ):
+            logger.warning(
+                f"Table mapping file not found: {self._config.table_mapping_file}"
+            )
             return {}

         try:
-            with open(self._config.table_mapping_file, 'r', encoding='utf-8') as f:
+            with open(self._config.table_mapping_file, "r", encoding="utf-8") as f:
                 mapping = json.load(f)
                 return {
-                    key: value.split('.')[-1] if '.' in value else value
+                    key: value.split(".")[-1] if "." in value else value
                     for key, value in mapping.items()
                 }
         except Exception as e:
@@ -164,11 +172,28 @@ class BenchmarkDataManager(BaseComponent):
         mapped_name = self._table_mappings.get(name.lower(), name)

         # Clean special characters
-        invalid_chars = ['-', ' ', '.', ',', ';', ':', '!', '?', "'", '"', '(', ')', '[', ']', '{', '}']
+        invalid_chars = [
+            "-",
+            " ",
+            ".",
+            ",",
+            ";",
+            ":",
+            "!",
+            "?",
+            "'",
+            '"',
+            "(",
+            ")",
+            "[",
+            "]",
+            "{",
+            "}",
+        ]
         for char in invalid_chars:
-            mapped_name = mapped_name.replace(char, '_')
-        while '__' in mapped_name:
-            mapped_name = mapped_name.replace('__', '_')
+            mapped_name = mapped_name.replace(char, "_")
+        while "__" in mapped_name:
+            mapped_name = mapped_name.replace("__", "_")

         return mapped_name.lower()

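The sanitizer above swaps out a fixed list of characters one at a time. An equivalent normalization (apply the mapping, replace punctuation and whitespace with "_", collapse repeats, lower-case) can be written with a single regular expression; this is only an illustrative sketch, not the project's implementation:

    import re
    from typing import Dict, Optional


    def sanitize_table_name(name: str, mappings: Optional[Dict[str, str]] = None) -> str:
        """Map the raw name, then normalize it the way _sanitize_table_name does."""
        mapped = (mappings or {}).get(name.lower(), name)
        mapped = re.sub(r"""[-\s.,;:!?'"()\[\]{}]""", "_", mapped)
        return re.sub(r"_+", "_", mapped).lower()


    print(sanitize_table_name("My Table (v2)"))  # -> "my_table_v2_"

The only behavioral difference is that \s also matches tabs and newlines, which the explicit character list does not.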
@@ -183,15 +208,17 @@ class BenchmarkDataManager(BaseComponent):

         # Download fresh copy
         self.temp_dir = tempfile.mkdtemp()
-        zip_url = repo_url.replace('github.com', 'api.github.com/repos') + '/zipball/main'
+        zip_url = (
+            repo_url.replace("github.com", "api.github.com/repos") + "/zipball/main"
+        )
         logger.info(f"Downloading from GitHub repo: {zip_url}")

         try:
             async with self._http_session.get(zip_url) as response:
                 response.raise_for_status()
-                zip_path = os.path.join(self.temp_dir, 'repo.zip')
+                zip_path = os.path.join(self.temp_dir, "repo.zip")

-                with open(zip_path, 'wb') as f:
+                with open(zip_path, "wb") as f:
                     while True:
                         chunk = await response.content.read(1024)
                         if not chunk:
@@ -210,7 +237,7 @@ class BenchmarkDataManager(BaseComponent):

     def _get_cache_path(self, repo_url: str) -> str:
         """Get path to cached zip file"""
-        cache_key = hashlib.md5(repo_url.encode('utf-8')).hexdigest()
+        cache_key = hashlib.md5(repo_url.encode("utf-8")).hexdigest()
         return os.path.join(self._config.cache_dir, f"{cache_key}.zip")

     def _is_cache_valid(self, cache_path: str) -> bool:
@@ -227,11 +254,14 @@ class BenchmarkDataManager(BaseComponent):

     def _extract_zip(self, zip_path: str) -> str:
         """Extract zip to temp directory"""
-        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+        with zipfile.ZipFile(zip_path, "r") as zip_ref:
             zip_ref.extractall(self.temp_dir)

-        extracted_dirs = [d for d in os.listdir(self.temp_dir)
-                          if os.path.isdir(os.path.join(self.temp_dir, d))]
+        extracted_dirs = [
+            d
+            for d in os.listdir(self.temp_dir)
+            if os.path.isdir(os.path.join(self.temp_dir, d))
+        ]
         if not extracted_dirs:
             raise ValueError("No valid directory found after extraction")
         return os.path.join(self.temp_dir, extracted_dirs[0])
@@ -245,13 +275,15 @@ class BenchmarkDataManager(BaseComponent):
         csv_files = []
         for root, _, files in os.walk(full_search_dir):
             for file in files:
-                if file.lower().endswith('.csv'):
+                if file.lower().endswith(".csv"):
                     rel_path = os.path.relpath(root, start=base_dir)
-                    csv_files.append({
-                        'full_path': os.path.join(root, file),
-                        'rel_path': rel_path,
-                        'file_name': file
-                    })
+                    csv_files.append(
+                        {
+                            "full_path": os.path.join(root, file),
+                            "rel_path": rel_path,
+                            "file_name": file,
+                        }
+                    )
         return csv_files

     async def _import_to_database(self, csv_files: List[Dict]) -> Dict:
@@ -259,24 +291,24 @@ class BenchmarkDataManager(BaseComponent):
         conn = await self.get_connection()
         cursor = conn.cursor()
         results = {
-            'total_files': len(csv_files),
-            'successful': 0,
-            'failed': 0,
-            'tables_created': []
+            "total_files": len(csv_files),
+            "successful": 0,
+            "failed": 0,
+            "tables_created": [],
         }

         for file_info in csv_files:
             try:
-                path_parts = [p for p in file_info['rel_path'].split(os.sep) if p]
-                table_name = '_'.join(path_parts + [Path(file_info['file_name']).stem])
+                path_parts = [p for p in file_info["rel_path"].split(os.sep) if p]
+                table_name = "_".join(path_parts + [Path(file_info["file_name"]).stem])
                 table_name = self._sanitize_table_name(table_name)

                 # Try multiple encodings
-                encodings = ['utf-8-sig', 'utf-8', 'latin-1', 'iso-8859-1', 'cp1252']
+                encodings = ["utf-8-sig", "utf-8", "latin-1", "iso-8859-1", "cp1252"]

                 for encoding in encodings:
                     try:
-                        with open(file_info['full_path'], 'r', encoding=encoding) as f:
+                        with open(file_info["full_path"], "r", encoding=encoding) as f:
                             content = f.read()

                             # Handle empty files
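The loop above retries each CSV file against the listed encodings until one decodes. The same fallback pattern as a standalone sketch (a hypothetical helper, simplified from the method's inline logic):

    from typing import List, Optional


    def read_text_with_fallback(path: str, encodings: Optional[List[str]] = None) -> str:
        """Try each encoding in order and return the first successful decode."""
        encodings = encodings or ["utf-8-sig", "utf-8", "latin-1", "iso-8859-1", "cp1252"]
        last_error: Optional[Exception] = None
        for encoding in encodings:
            try:
                with open(path, "r", encoding=encoding) as f:
                    return f.read()
            except UnicodeDecodeError as e:
                last_error = e
        raise ValueError(f"Could not decode {path} with any of {encodings}") from last_error

Note that latin-1 accepts any byte sequence, so with this encoding list the final error path is effectively unreachable; the earlier entries only matter for files that are valid UTF-8.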
@@ -284,12 +316,13 @@ class BenchmarkDataManager(BaseComponent):
                                 raise ValueError("File is empty")

                             # Replace problematic line breaks if needed
-                            content = content.replace('\r\n', '\n').replace('\r', '\n')
+                            content = content.replace("\r\n", "\n").replace("\r", "\n")

                             # Split into lines
-                            lines = [line for line in content.split('\n') if line.strip()]
+                            lines = [
+                                line for line in content.split("\n") if line.strip()
+                            ]

-                            # Parse header and first data line to detect actual structure
                             try:
                                 header_line = lines[0]
                                 data_line = lines[1] if len(lines) > 1 else ""
@@ -300,27 +333,46 @@ class BenchmarkDataManager(BaseComponent):
                                 has_header = sniffer.has_header(content[:1024])

                                 if has_header:
-                                    headers = list(csv.reader([header_line], dialect))[0]
-                                    first_data_row = list(csv.reader([data_line], dialect))[0] if data_line else []
+                                    headers = list(csv.reader([header_line], dialect))[
+                                        0
+                                    ]
+                                    first_data_row = (
+                                        list(csv.reader([data_line], dialect))[0]
+                                        if data_line
+                                        else []
+                                    )
                                 else:
-                                    headers = list(csv.reader([header_line], dialect))[0]
+                                    headers = list(csv.reader([header_line], dialect))[
+                                        0
+                                    ]
                                     first_data_row = headers  # first line is data
-                                    headers = [f'col_{i}' for i in range(len(headers))]
+                                    headers = [f"col_{i}" for i in range(len(headers))]

                                 # Determine actual number of columns from data
-                                actual_columns = len(first_data_row) if first_data_row else len(headers)
+                                actual_columns = (
+                                    len(first_data_row)
+                                    if first_data_row
+                                    else len(headers)
+                                )

                                 # Create table with correct number of columns
                                 create_sql = f"""
-                                    CREATE TABLE IF NOT EXISTS {table_name} (
-                                        {', '.join([f'"{h}" TEXT' for h in headers[:actual_columns]])}
-                                    )
+                                    CREATE TABLE IF NOT EXISTS {table_name} ({
+                                        ", ".join(
+                                            [
+                                                f'"{h}" TEXT'
+                                                for h in headers[:actual_columns]
+                                            ]
+                                        )
+                                    })
                                 """
                                 cursor.execute(create_sql)

                                 # Prepare insert statement
                                 insert_sql = f"""
-                                    INSERT INTO {table_name} VALUES ({', '.join(['?'] * actual_columns)})
+                                    INSERT INTO {table_name} VALUES ({
+                                        ", ".join(["?"] * actual_columns)
+                                    })
                                 """

                                 # Process data
@@ -335,9 +387,6 @@ class BenchmarkDataManager(BaseComponent):

                                     # Ensure row has correct number of columns
                                     if len(row) != actual_columns:
-                                        logger.warning(
-                                            f"Adjusting row with {len(row)} columns to match {actual_columns} columns"
-                                        )
                                         if len(row) < actual_columns:
                                             row += [None] * (actual_columns - len(row))
                                         else:
@@ -351,16 +400,12 @@ class BenchmarkDataManager(BaseComponent):
                             if batch:
                                 cursor.executemany(insert_sql, batch)

-                            results['successful'] += 1
-                            results['tables_created'].append(table_name)
-                            logger.info(
-                                f"Imported: {file_info['rel_path']}/{file_info['file_name']} -> {table_name}"
-                            )
+                            results["successful"] += 1
+                            results["tables_created"].append(table_name)
                             break

                     except csv.Error as e:
                         # Fallback for malformed CSV files
-                        logger.warning(f"CSV parsing error ({encoding}): {str(e)}. Trying simple split")
                         self._import_with_simple_split(
                             cursor, table_name, content, results, file_info
                         )
@@ -374,8 +419,8 @@ class BenchmarkDataManager(BaseComponent):
                 else:
                     # All encodings failed - try binary mode as last resort
                     try:
-                        with open(file_info['full_path'], 'rb') as f:
-                            content = f.read().decode('ascii', errors='ignore')
+                        with open(file_info["full_path"], "rb") as f:
+                            content = f.read().decode("ascii", errors="ignore")
                         if content.strip():
                             self._import_with_simple_split(
                                 cursor, table_name, content, results, file_info
@@ -383,28 +428,32 @@ class BenchmarkDataManager(BaseComponent):
                         else:
                             raise ValueError("File is empty or unreadable")
                     except Exception as e:
-                        results['failed'] += 1
-                        logger.error(f"Failed to process {file_info['file_name']}: {str(e)}")
+                        results["failed"] += 1
+                        logger.error(
+                            f"Failed to process {file_info['file_name']}: {str(e)}"
+                        )

             except Exception as e:
-                results['failed'] += 1
+                results["failed"] += 1
                 logger.error(f"Failed to process {file_info['full_path']}: {str(e)}")

         self._db_conn.commit()
         return results

-    def _import_with_simple_split(self, cursor, table_name, content, results, file_info):
+    def _import_with_simple_split(
+        self, cursor, table_name, content, results, file_info
+    ):
         """Fallback method for malformed CSV files"""
         # Normalize line endings
-        content = content.replace('\r\n', '\n').replace('\r', '\n')
-        lines = [line for line in content.split('\n') if line.strip()]
+        content = content.replace("\r\n", "\n").replace("\r", "\n")
+        lines = [line for line in content.split("\n") if line.strip()]

         if not lines:
             raise ValueError("No data found after cleaning")

         # Determine delimiter
         first_line = lines[0]
-        delimiter = ',' if ',' in first_line else '\t' if '\t' in first_line else ';'
+        delimiter = "," if "," in first_line else "\t" if "\t" in first_line else ";"

         # Process header
         headers = first_line.split(delimiter)
@@ -413,14 +462,14 @@ class BenchmarkDataManager(BaseComponent):
         # Create table
         create_sql = f"""
             CREATE TABLE IF NOT EXISTS {table_name} (
-                {', '.join([f'col_{i} TEXT' for i in range(actual_columns)])}
+                {", ".join([f"col_{i} TEXT" for i in range(actual_columns)])}
             )
         """
         cursor.execute(create_sql)

         # Prepare insert
         insert_sql = f"""
-            INSERT INTO {table_name} VALUES ({', '.join(['?'] * actual_columns)})
+            INSERT INTO {table_name} VALUES ({", ".join(["?"] * actual_columns)})
         """

         # Process data
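Both statements are f-strings over a generated column list, so for a table with three columns the fallback path produces SQL along these lines (a quick illustration with a made-up table name, not repository code):

    table_name = "falcon_sample"  # example name, not from the repository
    actual_columns = 3

    create_sql = f"""
        CREATE TABLE IF NOT EXISTS {table_name} (
            {", ".join([f"col_{i} TEXT" for i in range(actual_columns)])}
        )
    """
    insert_sql = f"""
        INSERT INTO {table_name} VALUES ({", ".join(["?"] * actual_columns)})
    """
    print(create_sql)  # ... col_0 TEXT, col_1 TEXT, col_2 TEXT ...
    print(insert_sql)  # ... VALUES (?, ?, ?)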
@@ -428,9 +477,6 @@ class BenchmarkDataManager(BaseComponent):
         for line in lines[1:]:  # skip header
             row = line.split(delimiter)
             if len(row) != actual_columns:
-                logger.warning(
-                    f"Adjusting row with {len(row)} columns to match {actual_columns} columns"
-                )
                 if len(row) < actual_columns:
                     row += [None] * (actual_columns - len(row))
                 else:
@@ -444,11 +490,8 @@ class BenchmarkDataManager(BaseComponent):
         if batch:
             cursor.executemany(insert_sql, batch)

-        results['successful'] += 1
-        results['tables_created'].append(table_name)
-        logger.info(
-            f"Imported (simple split): {file_info['rel_path']}/{file_info['file_name']} -> {table_name}"
-        )
+        results["successful"] += 1
+        results["tables_created"].append(table_name)

     def _cleanup_temp_dir(self):
         """Clean up temporary directory"""
@@ -463,7 +506,9 @@ class BenchmarkDataManager(BaseComponent):
 _SYSTEM_APP: Optional[SystemApp] = None


-def initialize_benchmark_data(system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None):
+def initialize_benchmark_data(
+    system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None
+):
     """Initialize benchmark data manager component"""
     global _SYSTEM_APP
     _SYSTEM_APP = system_app
@@ -472,7 +517,9 @@ def initialize_benchmark_data(system_app: SystemApp, config: Optional[BenchmarkD
     return manager


-def get_benchmark_manager(system_app: Optional[SystemApp] = None) -> BenchmarkDataManager:
+def get_benchmark_manager(
+    system_app: Optional[SystemApp] = None,
+) -> BenchmarkDataManager:
     """Get the benchmark data manager instance"""
     if not _SYSTEM_APP:
         if not system_app: