feat: async loading of benchmark data

Author: yaoyifan-yyf
Date: 2025-09-25 11:26:13 +08:00
Commit: 9e50bff12c (parent a0f413b915)
7 changed files with 614 additions and 0 deletions

View File

@@ -1,3 +1,4 @@
import asyncio
import logging
import os
import sys
@@ -35,6 +36,7 @@ from dbgpt_app.base import (
from dbgpt_app.component_configs import initialize_components
from dbgpt_app.config import ApplicationConfig, ServiceWebParameters, SystemParameters
from dbgpt_serve.core import add_exception_handler
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
    BenchmarkDataManager,
    get_benchmark_manager,
)
logger = logging.getLogger(__name__)
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -144,6 +146,15 @@ def initialize_app(param: ApplicationConfig, args: List[str] = None):
# After init, when the database is ready
system_app.after_init()
# Asynchronously fetch the benchmark dataset from the Falcon repository.
# asyncio.get_event_loop() is deprecated outside a running event loop on newer
# Python versions, so schedule on the running loop if there is one, otherwise
# drive the coroutine to completion before continuing startup.
try:
    loop = asyncio.get_running_loop()
    loop.create_task(load_benchmark_data())
except RuntimeError:
    # No event loop running yet (normal during synchronous startup).
    asyncio.run(load_benchmark_data())
binding_port = web_config.port
binding_host = web_config.host
if not web_config.light:
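
For reference, a minimal standalone sketch of the scheduling pattern used above: attach the coroutine to an already-running loop, otherwise drive it to completion. fetch_data here is a hypothetical stand-in for load_benchmark_data.

import asyncio

async def fetch_data() -> str:
    # Stand-in for load_benchmark_data(); simulates async I/O.
    await asyncio.sleep(0.1)
    return "done"

def schedule_or_run(coro_factory):
    """Schedule on a running loop if present, otherwise block until complete."""
    try:
        loop = asyncio.get_running_loop()
        return loop.create_task(coro_factory())  # fire-and-forget on the live loop
    except RuntimeError:
        return asyncio.run(coro_factory())  # no loop yet: run synchronously to completion

if __name__ == "__main__":
    print(schedule_or_run(fetch_data))  # prints "done"
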
@@ -319,6 +330,43 @@ def parse_args():
return parser.parse_args()
async def load_benchmark_data():
"""Load benchmark data from GitHub repository into SQLite database"""
logging.basicConfig(level=logging.INFO)
logger.info("Starting benchmark data loading process...")
try:
manager = get_benchmark_manager(system_app)
async with manager:
logger.info("Fetching data from GitHub repository...")
result = await manager.load_from_github(
repo_url="https://github.com/inclusionAI/Falcon",
data_dir="data/source"
)
# Log detailed results
logger.info("\nBenchmark Data Loading Summary:")
logger.info(f"Total CSV files processed: {result['total_files']}")
logger.info(f"Successfully imported: {result['successful']}")
logger.info(f"Failed imports: {result['failed']}")
if result['failed'] > 0:
logger.warning(f"Encountered {result['failed']} failures during import")
# Verify the loaded data
table_info = await manager.get_table_info()
logger.info(f"Loaded {len(table_info)} tables into database")
return {
'import_result': result,
'table_info': table_info
}
except Exception as e:
logger.error("Failed to load benchmark data", exc_info=True)
raise RuntimeError(f"Benchmark data loading failed: {str(e)}") from e
if __name__ == "__main__":
# Parse command line arguments
_args = parse_args()
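
For illustration, a hedged sketch of how a caller might consume the dictionary returned by load_benchmark_data; the import_result and table_info keys mirror the function above, and the reporting logic is illustrative only.

import asyncio

async def report_benchmark_load():
    summary = await load_benchmark_data()
    result = summary["import_result"]
    print(f"{result['successful']}/{result['total_files']} CSV files imported, {result['failed']} failed")
    for table, info in summary["table_info"].items():
        print(f"{table}: {info['row_count']} rows, {len(info['columns'])} columns")

# asyncio.run(report_benchmark_load())  # requires the application context set up by initialize_app
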

View File

@@ -101,6 +101,7 @@ class ComponentType(str, Enum):
RESOURCE_MANAGER = "dbgpt_resource_manager"
VARIABLES_PROVIDER = "dbgpt_variables_provider"
FILE_STORAGE_CLIENT = "dbgpt_file_storage_client"
BENCHMARK_DATA_MANAGER = "dbgpt_benchmark_data_manager"
_EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT"
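
The new enum member gives the manager a stable component name. A minimal sketch of registering and resolving it through SystemApp, using only calls that appear elsewhere in this commit (register_instance and get_instance); the import path mirrors the one added to the server entrypoint above.

from dbgpt.component import SystemApp
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import BenchmarkDataManager

system_app = SystemApp()
manager = BenchmarkDataManager(system_app)
system_app.register_instance(manager)  # registered under ComponentType.BENCHMARK_DATA_MANAGER
assert BenchmarkDataManager.get_instance(system_app) is manager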

View File

@@ -0,0 +1,481 @@
import os
import csv
import sqlite3
import aiohttp
import asyncio
import zipfile
import tempfile
import shutil
import time
import hashlib
from pathlib import Path
from typing import List, Dict, Optional, Type
import logging
import json
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt._private.pydantic import BaseModel, ConfigDict
logger = logging.getLogger(__name__)
class BenchmarkDataConfig(BaseModel):
"""Configuration for Benchmark Data Manager"""
model_config = ConfigDict(arbitrary_types_allowed=True)
cache_dir: str = "cache"
db_path: str = "benchmark_data.db"
table_mapping_file: Optional[str] = None
cache_expiry_days: int = 1
class BenchmarkDataManager(BaseComponent):
"""Manage benchmark data lifecycle including fetching, transformation and storage"""
name = ComponentType.BENCHMARK_DATA_MANAGER
def __init__(self, system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None):
super().__init__(system_app)
self._config = config or BenchmarkDataConfig()
self._http_session = None
self._db_conn = None
self._table_mappings = self._load_mappings()
self._lock = asyncio.Lock()
self.temp_dir = None
# Ensure directories exist
os.makedirs(self._config.cache_dir, exist_ok=True)
def init_app(self, system_app: SystemApp):
"""Initialize the AgentManager."""
self.system_app = system_app
async def __aenter__(self):
self._http_session = aiohttp.ClientSession()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
async def close(self):
"""Clean up resources"""
if self._http_session:
await self._http_session.close()
self._http_session = None
if self._db_conn:
self._db_conn.close()
self._db_conn = None
self._cleanup_temp_dir()
async def get_connection(self) -> sqlite3.Connection:
"""Get database connection (thread-safe)"""
async with self._lock:
if not self._db_conn:
self._db_conn = sqlite3.connect(self._config.db_path)
return self._db_conn
async def query(self, query: str, params: tuple = ()) -> List[Dict]:
    """Execute a SQL statement and return its rows as a list of dicts"""
    conn = await self.get_connection()
    cursor = conn.cursor()
    cursor.execute(query, params)
    if cursor.description is None:
        # Statements without a result set (DDL, INSERT, ...) return no rows.
        return []
    columns = [col[0] for col in cursor.description]
    return [dict(zip(columns, row)) for row in cursor.fetchall()]
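
For illustration, a hedged usage sketch of query() with a parameterized statement; the table name city_ride_data_rides is an assumption based on table_mapping.json below, and the database is assumed to have already been populated by load_from_github.

import asyncio
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import get_benchmark_manager

async def preview_rides():
    manager = get_benchmark_manager()
    async with manager:
        rows = await manager.query("SELECT * FROM city_ride_data_rides LIMIT ?", (5,))
        for row in rows:  # each row is a dict keyed by column name
            print(row)

# asyncio.run(preview_rides())
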
async def load_from_github(self, repo_url: str, data_dir: str = "data/source") -> Dict:
"""Main method to load data from GitHub repository"""
try:
# 1. Download or use cached repository
repo_dir = await self._download_repo_contents(repo_url)
# 2. Find all CSV files recursively
csv_files = self._discover_csv_files(repo_dir, data_dir)
if not csv_files:
raise ValueError("No CSV files found")
logger.info(f"Found {len(csv_files)} CSV files")
# 3. Import to SQLite
result = await self._import_to_database(csv_files)
logger.info(f"Import completed: {result['successful']} succeeded, {result['failed']} failed")
return result
except Exception as e:
logger.error(f"Import failed: {str(e)}")
raise
finally:
self._cleanup_temp_dir()
async def get_table_info(self) -> Dict:
"""Get metadata about all tables"""
conn = await self.get_connection()
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
tables = cursor.fetchall()
result = {}
for table in tables:
table_name = table[0]
cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
row_count = cursor.fetchone()[0]
cursor.execute(f"PRAGMA table_info({table_name})")
columns = cursor.fetchall()
result[table_name] = {
'row_count': row_count,
'columns': [{'name': col[1], 'type': col[2]} for col in columns]
}
return result
def clear_cache(self):
"""Clear cached repository files"""
try:
for filename in os.listdir(self._config.cache_dir):
file_path = os.path.join(self._config.cache_dir, filename)
try:
if os.path.isfile(file_path):
os.unlink(file_path)
except Exception as e:
logger.error(f"Failed to delete {file_path}: {str(e)}")
logger.info("Cache cleared successfully")
except Exception as e:
logger.error(f"Failed to clear cache: {str(e)}")
def _load_mappings(self) -> Dict[str, str]:
"""Load table name mappings from config file"""
if not self._config.table_mapping_file or not os.path.exists(self._config.table_mapping_file):
logger.warning(f"Table mapping file not found: {self._config.table_mapping_file}")
return {}
try:
with open(self._config.table_mapping_file, 'r', encoding='utf-8') as f:
mapping = json.load(f)
return {
key: value.split('.')[-1] if '.' in value else value
for key, value in mapping.items()
}
except Exception as e:
logger.error(f"Failed to load table mapping: {str(e)}")
return {}
def _sanitize_table_name(self, name: str) -> str:
"""Normalize table names using mappings"""
mapped_name = self._table_mappings.get(name.lower(), name)
# Clean special characters
invalid_chars = ['-', ' ', '.', ',', ';', ':', '!', '?', "'", '"', '(', ')', '[', ']', '{', '}']
for char in invalid_chars:
mapped_name = mapped_name.replace(char, '_')
while '__' in mapped_name:
mapped_name = mapped_name.replace('__', '_')
return mapped_name.lower()
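
A couple of hypothetical inputs and the table names _sanitize_table_name would produce, assuming the mappings from table_mapping.json below have been loaded:

# "data_source_20_rides_data"  ->  "city_ride_data_rides"  (mapping hit; the schema prefix was already stripped at load time)
# "Sales Report (2024).v2"     ->  "sales_report_2024_v2"  (no mapping: punctuation becomes "_", doubles collapse, result lowercased)
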
async def _download_repo_contents(self, repo_url: str) -> str:
"""Download repository with caching"""
cache_path = self._get_cache_path(repo_url)
# Use cache if valid
if os.path.exists(cache_path) and self._is_cache_valid(cache_path):
logger.info(f"Using cached repository: {cache_path}")
return self._extract_cache(cache_path)
# Download fresh copy
self.temp_dir = tempfile.mkdtemp()
zip_url = repo_url.replace('github.com', 'api.github.com/repos') + '/zipball/main'
logger.info(f"Downloading from GitHub repo: {zip_url}")
try:
async with self._http_session.get(zip_url) as response:
response.raise_for_status()
zip_path = os.path.join(self.temp_dir, 'repo.zip')
with open(zip_path, 'wb') as f:
while True:
chunk = await response.content.read(1024)
if not chunk:
break
f.write(chunk)
# Cache the download
shutil.copy2(zip_path, cache_path)
logger.info(f"Saved repository to cache: {cache_path}")
return self._extract_zip(zip_path)
except Exception as e:
self._cleanup_temp_dir()
raise RuntimeError(f"Failed to download repository: {str(e)}")
def _get_cache_path(self, repo_url: str) -> str:
"""Get path to cached zip file"""
cache_key = hashlib.md5(repo_url.encode('utf-8')).hexdigest()
return os.path.join(self._config.cache_dir, f"{cache_key}.zip")
def _is_cache_valid(self, cache_path: str) -> bool:
"""Check if cache is still valid"""
if not os.path.exists(cache_path):
return False
file_age = time.time() - os.path.getmtime(cache_path)
return file_age < (self._config.cache_expiry_days * 24 * 60 * 60)
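
A short sketch of how the cache location and the freshness check combine, using the repository URL from this commit; the printed digest is abbreviated here.

import hashlib, os

repo_url = "https://github.com/inclusionAI/Falcon"
cache_path = os.path.join("cache", hashlib.md5(repo_url.encode("utf-8")).hexdigest() + ".zip")
print(cache_path)  # cache/<32-hex-digest>.zip
# The cached zip is reused while its age stays below cache_expiry_days (1 day by default):
# time.time() - os.path.getmtime(cache_path) < 1 * 24 * 60 * 60
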
def _extract_cache(self, cache_path: str) -> str:
"""Extract cached repository"""
self.temp_dir = tempfile.mkdtemp()
return self._extract_zip(cache_path)
def _extract_zip(self, zip_path: str) -> str:
"""Extract zip to temp directory"""
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(self.temp_dir)
extracted_dirs = [d for d in os.listdir(self.temp_dir)
if os.path.isdir(os.path.join(self.temp_dir, d))]
if not extracted_dirs:
raise ValueError("No valid directory found after extraction")
return os.path.join(self.temp_dir, extracted_dirs[0])
def _discover_csv_files(self, base_dir: str, search_dir: str) -> List[Dict]:
"""Find all CSV files recursively"""
full_search_dir = os.path.join(base_dir, search_dir) if search_dir else base_dir
if not os.path.exists(full_search_dir):
raise ValueError(f"Directory not found: {full_search_dir}")
csv_files = []
for root, _, files in os.walk(full_search_dir):
for file in files:
if file.lower().endswith('.csv'):
rel_path = os.path.relpath(root, start=base_dir)
csv_files.append({
'full_path': os.path.join(root, file),
'rel_path': rel_path,
'file_name': file
})
return csv_files
async def _import_to_database(self, csv_files: List[Dict]) -> Dict:
"""Import CSV data to SQLite"""
conn = await self.get_connection()
cursor = conn.cursor()
results = {
'total_files': len(csv_files),
'successful': 0,
'failed': 0,
'tables_created': []
}
for file_info in csv_files:
try:
path_parts = [p for p in file_info['rel_path'].split(os.sep) if p]
table_name = '_'.join(path_parts + [Path(file_info['file_name']).stem])
table_name = self._sanitize_table_name(table_name)
# Try multiple encodings
encodings = ['utf-8-sig', 'utf-8', 'latin-1', 'iso-8859-1', 'cp1252']
for encoding in encodings:
try:
with open(file_info['full_path'], 'r', encoding=encoding) as f:
content = f.read()
# Handle empty files
if not content.strip():
raise ValueError("File is empty")
# Replace problematic line breaks if needed
content = content.replace('\r\n', '\n').replace('\r', '\n')
# Split into lines
lines = [line for line in content.split('\n') if line.strip()]
# Parse header and first data line to detect actual structure
try:
header_line = lines[0]
data_line = lines[1] if len(lines) > 1 else ""
# Detect delimiter (comma, semicolon, tab)
sniffer = csv.Sniffer()
dialect = sniffer.sniff(header_line)
has_header = sniffer.has_header(content[:1024])
if has_header:
headers = list(csv.reader([header_line], dialect))[0]
first_data_row = list(csv.reader([data_line], dialect))[0] if data_line else []
else:
headers = list(csv.reader([header_line], dialect))[0]
first_data_row = headers # first line is data
headers = [f'col_{i}' for i in range(len(headers))]
# Determine actual number of columns from data
actual_columns = len(first_data_row) if first_data_row else len(headers)
# Create table with correct number of columns
create_sql = f"""
CREATE TABLE IF NOT EXISTS {table_name} (
{', '.join([f'"{h}" TEXT' for h in headers[:actual_columns]])}
)
"""
cursor.execute(create_sql)
# Prepare insert statement
insert_sql = f"""
INSERT INTO {table_name} VALUES ({', '.join(['?'] * actual_columns)})
"""
# Process data
batch = []
reader = csv.reader(lines, dialect)
if has_header:
next(reader) # skip header
for row in reader:
if not row: # skip empty rows
continue
# Ensure row has correct number of columns
if len(row) != actual_columns:
logger.warning(
f"Adjusting row with {len(row)} columns to match {actual_columns} columns"
)
if len(row) < actual_columns:
row += [None] * (actual_columns - len(row))
else:
row = row[:actual_columns]
batch.append(row)
if len(batch) >= 1000:
cursor.executemany(insert_sql, batch)
batch = []
if batch:
cursor.executemany(insert_sql, batch)
results['successful'] += 1
results['tables_created'].append(table_name)
logger.info(
f"Imported: {file_info['rel_path']}/{file_info['file_name']} -> {table_name}"
)
break
except csv.Error as e:
# Fallback for malformed CSV files
logger.warning(f"CSV parsing error ({encoding}): {str(e)}. Trying simple split")
self._import_with_simple_split(
cursor, table_name, content, results, file_info
)
break
except UnicodeDecodeError:
continue
except Exception as e:
logger.warning(f"Error with encoding {encoding}: {str(e)}")
continue
else:
# All encodings failed - try binary mode as last resort
try:
with open(file_info['full_path'], 'rb') as f:
content = f.read().decode('ascii', errors='ignore')
if content.strip():
self._import_with_simple_split(
cursor, table_name, content, results, file_info
)
else:
raise ValueError("File is empty or unreadable")
except Exception as e:
results['failed'] += 1
logger.error(f"Failed to process {file_info['file_name']}: {str(e)}")
except Exception as e:
results['failed'] += 1
logger.error(f"Failed to process {file_info['full_path']}: {str(e)}")
self._db_conn.commit()
return results
def _import_with_simple_split(self, cursor, table_name, content, results, file_info):
"""Fallback method for malformed CSV files"""
# Normalize line endings
content = content.replace('\r\n', '\n').replace('\r', '\n')
lines = [line for line in content.split('\n') if line.strip()]
if not lines:
raise ValueError("No data found after cleaning")
# Determine delimiter
first_line = lines[0]
delimiter = ',' if ',' in first_line else '\t' if '\t' in first_line else ';'
# Process header
headers = first_line.split(delimiter)
actual_columns = len(headers)
# Create table
create_sql = f"""
CREATE TABLE IF NOT EXISTS {table_name} (
{', '.join([f'col_{i} TEXT' for i in range(actual_columns)])}
)
"""
cursor.execute(create_sql)
# Prepare insert
insert_sql = f"""
INSERT INTO {table_name} VALUES ({', '.join(['?'] * actual_columns)})
"""
# Process data
batch = []
for line in lines[1:]: # skip header
row = line.split(delimiter)
if len(row) != actual_columns:
logger.warning(
f"Adjusting row with {len(row)} columns to match {actual_columns} columns"
)
if len(row) < actual_columns:
row += [None] * (actual_columns - len(row))
else:
row = row[:actual_columns]
batch.append(row)
if len(batch) >= 1000:
cursor.executemany(insert_sql, batch)
batch = []
if batch:
cursor.executemany(insert_sql, batch)
results['successful'] += 1
results['tables_created'].append(table_name)
logger.info(
f"Imported (simple split): {file_info['rel_path']}/{file_info['file_name']} -> {table_name}"
)
def _cleanup_temp_dir(self):
"""Clean up temporary directory"""
if self.temp_dir and os.path.exists(self.temp_dir):
try:
shutil.rmtree(self.temp_dir)
self.temp_dir = None
except Exception as e:
logger.warning(f"Failed to clean temp dir: {str(e)}")
_SYSTEM_APP: Optional[SystemApp] = None
def initialize_benchmark_data(system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None):
"""Initialize benchmark data manager component"""
global _SYSTEM_APP
_SYSTEM_APP = system_app
manager = BenchmarkDataManager(system_app, config)
system_app.register_instance(manager)
return manager
def get_benchmark_manager(system_app: Optional[SystemApp] = None) -> BenchmarkDataManager:
    """Get the benchmark data manager instance, initializing it on first use"""
    if not _SYSTEM_APP:
        if not system_app:
            system_app = SystemApp()
        initialize_benchmark_data(system_app)
    # Fall back to the app captured at initialization when no system_app is passed in.
    return BenchmarkDataManager.get_instance(system_app or _SYSTEM_APP)
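
Putting the pieces together, a hedged end-to-end sketch of the intended flow outside the web server. It mirrors load_benchmark_data above; the repo URL, data_dir, and config values are the ones used in this commit, and running it would perform a real download.

import asyncio
from dbgpt.component import SystemApp
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
    BenchmarkDataConfig,
    initialize_benchmark_data,
)

async def main():
    system_app = SystemApp()
    manager = initialize_benchmark_data(
        system_app, BenchmarkDataConfig(cache_dir="cache", db_path="benchmark_data.db")
    )
    async with manager:  # opens the aiohttp session; closes it and the DB connection on exit
        result = await manager.load_from_github(
            repo_url="https://github.com/inclusionAI/Falcon", data_dir="data/source"
        )
        print(f"{result['successful']} of {result['total_files']} CSV files imported")
        print(await manager.get_table_info())

# asyncio.run(main())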

table_mapping.json (new file, 84 lines)
View File

@@ -0,0 +1,84 @@
{
"data_source_20_rides_data": "ant_icube_dev.city_ride_data_rides",
"data_source_20_drivers_data": "ant_icube_dev.city_ride_data_drivers",
"data_source_18_richest_countries": "ant_icube_dev.world_economic_richest_countries",
"data_source_18_cost_of_living": "ant_icube_dev.world_economic_cost_of_living",
"data_source_18_tourism": "ant_icube_dev.world_economic_tourism",
"data_source_18_corruption": "ant_icube_dev.world_economic_corruption",
"data_source_18_unemployment": "ant_icube_dev.world_economic_unemployment",
"data_source_27_customers": "ant_icube_dev.grocery_sales_customers",
"data_source_27_categories": "ant_icube_dev.grocery_sales_categories",
"data_source_27_products": "ant_icube_dev.grocery_sales_products",
"data_source_27_countries": "ant_icube_dev.grocery_sales_countries",
"data_source_27_cities": "ant_icube_dev.grocery_sales_cities",
"data_source_27_employees": "ant_icube_dev.grocery_sales_employees",
"data_source_11_price": "ant_icube_dev.bakery_sales_price",
"data_source_11_sales": "ant_icube_dev.bakery_sales_sale",
"data_source_7_vgsales": "ant_icube_dev.di_video_game_sales",
"data_source_16_subjects": "ant_icube_dev.school_subject",
"data_source_16_marks": "ant_icube_dev.school_marks",
"data_source_16_teachers": "ant_icube_dev.school_teachers",
"data_source_16_students": "ant_icube_dev.school_students",
"data_source_6_sales_dataset": "ant_icube_dev.di_sales_dataset",
"data_source_28_customers": "ant_icube_dev.online_shop_customers",
"data_source_28_products": "ant_icube_dev.online_shop_products",
"data_source_28_reviews": "ant_icube_dev.online_shop_reviews",
"data_source_28_orders": "ant_icube_dev.online_shop_orders",
"data_source_28_shipments": "ant_icube_dev.online_shop_shipments",
"data_source_28_suppliers": "ant_icube_dev.online_shop_suppliers",
"data_source_28_payment": "ant_icube_dev.online_shop_payment",
"data_source_28_order_items": "ant_icube_dev.online_shop_order_items",
"data_source_17_df_customers": "ant_icube_dev.ecommerce_order_customers",
"data_source_17_df_products": "ant_icube_dev.ecommerce_order_products",
"data_source_17_df_payments": "ant_icube_dev.ecommerce_order_payments",
"data_source_17_df_orders": "ant_icube_dev.ecommerce_order_orders",
"data_source_17_df_orderitems": "ant_icube_dev.ecommerce_order_order_items",
"data_source_1_finance_data": "ant_icube_dev.di_finance_data",
"data_source_10_indexinfo": "ant_icube_dev.stock_exchange_index_info",
"data_source_10_indexdata": "ant_icube_dev.stock_exchange_index_data",
"data_source_19_drinks": "ant_icube_dev.alcohol_and_life_expectancy_drinks",
"data_source_19_lifeexpectancy_verbose": "ant_icube_dev.alcohol_and_life_expectancy_verbose",
"data_source_26_teams": "ant_icube_dev.football_teams",
"data_source_26_appearances": "ant_icube_dev.football_appereances",
"data_source_26_teamstats": "ant_icube_dev.football_teamstats",
"data_source_26_leagues": "ant_icube_dev.football_leagues",
"data_source_26_players": "ant_icube_dev.football_players",
"data_source_26_games": "ant_icube_dev.football_games",
"data_source_26_shots": "ant_icube_dev.football_shots",
"data_source_8_googleplaystore": "ant_icube_dev.di_google_play_store_apps",
"data_source_21_e_customers": "ant_icube_dev.di_data_cleaning_for_customer_database_e_customers",
"data_source_21_e_products": "ant_icube_dev.di_data_cleaning_for_customer_database_e_products",
"data_source_21_e_orders": "ant_icube_dev.di_data_cleaning_for_customer_database_e_orders",
"data_source_24_blinkit_products": "ant_icube_dev.blinkit_products",
"data_source_24_blinkit_marketing_performance": "ant_icube_dev.blinkit_marketing_performance",
"data_source_24_blinkit_inventory": "ant_icube_dev.blinkit_inventory",
"data_source_24_blinkit_customer_feedback": "ant_icube_dev.blinkit_customer_feedback",
"data_source_24_blinkit_inventorynew": "ant_icube_dev.blinkit_inventory",
"data_source_24_blinkit_order_items": "ant_icube_dev.blinkit_order_items",
"data_source_24_blinkit_customers": "ant_icube_dev.blinkit_customers",
"data_source_24_blinkit_orders": "ant_icube_dev.blinkit_orders",
"data_source_24_blinkit_delivery_performance": "ant_icube_dev.blinkit_delivery_performance",
"data_source_23_ben10_aliens": "ant_icube_dev.di_ben10_alien_universe_realistic_battle_dataset_aliens",
"data_source_23_ben10_enemies": "ant_icube_dev.di_ben10_alien_universe_realistic_battle_dataset_enemies",
"data_source_23_ben10_battles": "ant_icube_dev.di_ben10_alien_universe_realistic_battle_dataset_battles",
"data_source_2_finance_loan_approval_prediction_data": "ant_icube_dev.di_finance_loan_approval_prediction_data",
"data_source_13_features": "ant_icube_dev.walmart_features",
"data_source_13_stores": "ant_icube_dev.walmart_stores",
"data_source_13_sales": "ant_icube_dev.walmart_sales",
"data_source_5_unicorns_till_sep_2022": "ant_icube_dev.di_unicorn_startups",
"data_source_14_products": "ant_icube_dev.mexico_toy_products",
"data_source_14_inventory": "ant_icube_dev.mexico_toy_inventory",
"data_source_14_stores": "ant_icube_dev.mexico_toy_stores",
"data_source_14_sales": "ant_icube_dev.mexico_toy_sales",
"data_source_22_ufc_country_data": "ant_icube_dev.ufc_country_data",
"data_source_22_ufc_events_stats": "ant_icube_dev.ufc_events_stats",
"data_source_22_ufc_fighters_stats": "ant_icube_dev.ufc_fighters_stats",
"data_source_25_bakutech_bakutech_sales_data": "ant_icube_dev.tech_sales_sales_data",
"data_source_25_bakutech_bakutech_assets": "ant_icube_dev.tech_sales_assets",
"data_source_25_bakutech_bakutech_product_returns": "ant_icube_dev.tech_sales_product_returns",
"data_source_25_bakutech_bakutech_product_subcategories": "ant_icube_dev.tech_sales_product_subcategories",
"data_source_25_bakutech_bakutech_customer_lookup": "ant_icube_dev.tech_sales_customer_lookup",
"data_source_25_bakutech_bakutech_product_categories": "ant_icube_dev.tech_sales_product_categories",
"data_source_25_bakutech_bakutech_products_lookup": "ant_icube_dev.tech_sales_product_lookup",
"data_source_25_bakutech_bakutech_dates": "ant_icube_dev.tech_sales_dates"
}
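
When _load_mappings reads this file it keeps only the part after the schema qualifier, so the SQLite tables are created without the ant_icube_dev. prefix. A minimal sketch of that transformation; the path is whatever table_mapping_file points to.

import json

with open("table_mapping.json", encoding="utf-8") as f:
    mapping = json.load(f)

tables = {k: v.split(".")[-1] if "." in v else v for k, v in mapping.items()}
print(tables["data_source_20_rides_data"])  # -> city_ride_data_rides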