From bc140848264858d84e71bc335fc794d6459b2d3c Mon Sep 17 00:00:00 2001
From: yaoyifan-yyf
Date: Thu, 25 Sep 2025 13:37:57 +0800
Subject: [PATCH] opt: code format

---
 .../dbgpt-app/src/dbgpt_app/dbgpt_server.py   |  17 +-
 .../fetchdata/benchmark_data_manager.py       | 225 +++++++++++-------
 2 files changed, 143 insertions(+), 99 deletions(-)

diff --git a/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py b/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py
index 54e592c77..cc9413a33 100644
--- a/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py
+++ b/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py
@@ -36,7 +36,9 @@ from dbgpt_app.base import (
 from dbgpt_app.component_configs import initialize_components
 from dbgpt_app.config import ApplicationConfig, ServiceWebParameters, SystemParameters
 from dbgpt_serve.core import add_exception_handler
-from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import BenchmarkDataManager, get_benchmark_manager
+from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
+    get_benchmark_manager,
+)
 
 logger = logging.getLogger(__name__)
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -153,8 +155,6 @@ def initialize_app(param: ApplicationConfig, args: List[str] = None):
     else:
         loop.run_until_complete(load_benchmark_data())
 
-
-
     binding_port = web_config.port
     binding_host = web_config.host
     if not web_config.light:
@@ -341,8 +341,7 @@ async def load_benchmark_data():
         async with manager:
             logger.info("Fetching data from GitHub repository...")
             result = await manager.load_from_github(
-                repo_url="https://github.com/inclusionAI/Falcon",
-                data_dir="data/source"
+                repo_url="https://github.com/inclusionAI/Falcon", data_dir="data/source"
             )
 
             # Log detailed results
@@ -351,22 +350,20 @@ async def load_benchmark_data():
             logger.info(f"Successfully imported: {result['successful']}")
             logger.info(f"Failed imports: {result['failed']}")
 
-            if result['failed'] > 0:
+            if result["failed"] > 0:
                 logger.warning(f"Encountered {result['failed']} failures during import")
 
             # Verify the loaded data
             table_info = await manager.get_table_info()
             logger.info(f"Loaded {len(table_info)} tables into database")
 
-            return {
-                'import_result': result,
-                'table_info': table_info
-            }
+            return {"import_result": result, "table_info": table_info}
 
     except Exception as e:
         logger.error("Failed to load benchmark data", exc_info=True)
         raise RuntimeError(f"Benchmark data loading failed: {str(e)}") from e
 
+
 if __name__ == "__main__":
     # Parse command line arguments
     _args = parse_args()
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py
index ae0ba5c3f..05c109a30 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py
@@ -1,20 +1,21 @@
-import os
-import csv
-import sqlite3
-import aiohttp
 import asyncio
-import zipfile
-import tempfile
-import shutil
-import time
+import csv
 import hashlib
-from pathlib import Path
-from typing import List, Dict, Optional, Type
-import logging
 import json
+import logging
+import os
+import shutil
+import sqlite3
+import tempfile
+import time
+import zipfile
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import aiohttp
 
-from dbgpt.component import BaseComponent, ComponentType, SystemApp
 from dbgpt._private.pydantic import BaseModel, ConfigDict
+from dbgpt.component import BaseComponent, ComponentType, SystemApp
 
 logger = logging.getLogger(__name__)
 
@@ -35,7 +36,9 @@ class BenchmarkDataManager(BaseComponent):
 
     name = ComponentType.BENCHMARK_DATA_MANAGER
 
-    def __init__(self, system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None):
+    def __init__(
+        self, system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None
+    ):
         super().__init__(system_app)
         self._config = config or BenchmarkDataConfig()
         self._http_session = None
@@ -83,7 +86,9 @@ class BenchmarkDataManager(BaseComponent):
             columns = [col[0] for col in cursor.description]
             return [dict(zip(columns, row)) for row in cursor.fetchall()]
 
-    async def load_from_github(self, repo_url: str, data_dir: str = "data/source") -> Dict:
+    async def load_from_github(
+        self, repo_url: str, data_dir: str = "data/source"
+    ) -> Dict:
         """Main method to load data from GitHub repository"""
         try:
             # 1. Download or use cached repository
@@ -97,7 +102,6 @@ class BenchmarkDataManager(BaseComponent):
 
             # 3. Import to SQLite
             result = await self._import_to_database(csv_files)
-            logger.info(f"Import completed: {result['successful']} succeeded, {result['failed']} failed")
 
             return result
         except Exception as e:
@@ -123,8 +127,8 @@ class BenchmarkDataManager(BaseComponent):
                 columns = cursor.fetchall()
 
                 result[table_name] = {
-                    'row_count': row_count,
-                    'columns': [{'name': col[1], 'type': col[2]} for col in columns]
+                    "row_count": row_count,
+                    "columns": [{"name": col[1], "type": col[2]} for col in columns],
                 }
 
             return result
@@ -144,15 +148,19 @@ class BenchmarkDataManager(BaseComponent):
 
     def _load_mappings(self) -> Dict[str, str]:
         """Load table name mappings from config file"""
-        if not self._config.table_mapping_file or not os.path.exists(self._config.table_mapping_file):
-            logger.warning(f"Table mapping file not found: {self._config.table_mapping_file}")
+        if not self._config.table_mapping_file or not os.path.exists(
+            self._config.table_mapping_file
+        ):
+            logger.warning(
+                f"Table mapping file not found: {self._config.table_mapping_file}"
+            )
             return {}
 
         try:
-            with open(self._config.table_mapping_file, 'r', encoding='utf-8') as f:
+            with open(self._config.table_mapping_file, "r", encoding="utf-8") as f:
                 mapping = json.load(f)
                 return {
-                    key: value.split('.')[-1] if '.' in value else value
+                    key: value.split(".")[-1] if "." in value else value
                     for key, value in mapping.items()
                 }
         except Exception as e:
@@ -164,11 +172,28 @@ class BenchmarkDataManager(BaseComponent):
         mapped_name = self._table_mappings.get(name.lower(), name)
 
         # Clean special characters
-        invalid_chars = ['-', ' ', '.', ',', ';', ':', '!', '?', "'", '"', '(', ')', '[', ']', '{', '}']
+        invalid_chars = [
+            "-",
+            " ",
+            ".",
+            ",",
+            ";",
+            ":",
+            "!",
+            "?",
+            "'",
+            '"',
+            "(",
+            ")",
+            "[",
+            "]",
+            "{",
+            "}",
+        ]
         for char in invalid_chars:
-            mapped_name = mapped_name.replace(char, '_')
-        while '__' in mapped_name:
-            mapped_name = mapped_name.replace('__', '_')
+            mapped_name = mapped_name.replace(char, "_")
+        while "__" in mapped_name:
+            mapped_name = mapped_name.replace("__", "_")
 
         return mapped_name.lower()
 
@@ -183,15 +208,17 @@ class BenchmarkDataManager(BaseComponent):
 
         # Download fresh copy
         self.temp_dir = tempfile.mkdtemp()
-        zip_url = repo_url.replace('github.com', 'api.github.com/repos') + '/zipball/main'
+        zip_url = (
+            repo_url.replace("github.com", "api.github.com/repos") + "/zipball/main"
+        )
         logger.info(f"Downloading from GitHub repo: {zip_url}")
 
         try:
             async with self._http_session.get(zip_url) as response:
                 response.raise_for_status()
-                zip_path = os.path.join(self.temp_dir, 'repo.zip')
+                zip_path = os.path.join(self.temp_dir, "repo.zip")
 
-                with open(zip_path, 'wb') as f:
+                with open(zip_path, "wb") as f:
                     while True:
                         chunk = await response.content.read(1024)
                         if not chunk:
                             break
@@ -210,7 +237,7 @@ class BenchmarkDataManager(BaseComponent):
 
     def _get_cache_path(self, repo_url: str) -> str:
         """Get path to cached zip file"""
-        cache_key = hashlib.md5(repo_url.encode('utf-8')).hexdigest()
+        cache_key = hashlib.md5(repo_url.encode("utf-8")).hexdigest()
         return os.path.join(self._config.cache_dir, f"{cache_key}.zip")
 
     def _is_cache_valid(self, cache_path: str) -> bool:
@@ -227,11 +254,14 @@ class BenchmarkDataManager(BaseComponent):
 
     def _extract_zip(self, zip_path: str) -> str:
         """Extract zip to temp directory"""
-        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+        with zipfile.ZipFile(zip_path, "r") as zip_ref:
             zip_ref.extractall(self.temp_dir)
 
-        extracted_dirs = [d for d in os.listdir(self.temp_dir)
-                          if os.path.isdir(os.path.join(self.temp_dir, d))]
+        extracted_dirs = [
+            d
+            for d in os.listdir(self.temp_dir)
+            if os.path.isdir(os.path.join(self.temp_dir, d))
+        ]
         if not extracted_dirs:
             raise ValueError("No valid directory found after extraction")
         return os.path.join(self.temp_dir, extracted_dirs[0])
@@ -245,13 +275,15 @@ class BenchmarkDataManager(BaseComponent):
         csv_files = []
         for root, _, files in os.walk(full_search_dir):
             for file in files:
-                if file.lower().endswith('.csv'):
+                if file.lower().endswith(".csv"):
                     rel_path = os.path.relpath(root, start=base_dir)
-                    csv_files.append({
-                        'full_path': os.path.join(root, file),
-                        'rel_path': rel_path,
-                        'file_name': file
-                    })
+                    csv_files.append(
+                        {
+                            "full_path": os.path.join(root, file),
+                            "rel_path": rel_path,
+                            "file_name": file,
+                        }
+                    )
         return csv_files
 
     async def _import_to_database(self, csv_files: List[Dict]) -> Dict:
@@ -259,24 +291,24 @@ class BenchmarkDataManager(BaseComponent):
         conn = await self.get_connection()
         cursor = conn.cursor()
         results = {
-            'total_files': len(csv_files),
-            'successful': 0,
-            'failed': 0,
-            'tables_created': []
+            "total_files": len(csv_files),
+            "successful": 0,
+            "failed": 0,
+            "tables_created": [],
        }
 
         for file_info in csv_files:
             try:
-                path_parts = [p for p in file_info['rel_path'].split(os.sep) if p]
-                table_name = '_'.join(path_parts + [Path(file_info['file_name']).stem])
+                path_parts = [p for p in file_info["rel_path"].split(os.sep) if p]
+                table_name = "_".join(path_parts + [Path(file_info["file_name"]).stem])
                 table_name = self._sanitize_table_name(table_name)
 
                 # Try multiple encodings
-                encodings = ['utf-8-sig', 'utf-8', 'latin-1', 'iso-8859-1', 'cp1252']
+                encodings = ["utf-8-sig", "utf-8", "latin-1", "iso-8859-1", "cp1252"]
 
                 for encoding in encodings:
                     try:
-                        with open(file_info['full_path'], 'r', encoding=encoding) as f:
+                        with open(file_info["full_path"], "r", encoding=encoding) as f:
                             content = f.read()
 
                             # Handle empty files
@@ -284,12 +316,13 @@ class BenchmarkDataManager(BaseComponent):
                                 raise ValueError("File is empty")
 
                             # Replace problematic line breaks if needed
-                            content = content.replace('\r\n', '\n').replace('\r', '\n')
+                            content = content.replace("\r\n", "\n").replace("\r", "\n")
 
                             # Split into lines
-                            lines = [line for line in content.split('\n') if line.strip()]
+                            lines = [
+                                line for line in content.split("\n") if line.strip()
+                            ]
 
-                            # Parse header and first data line to detect actual structure
                             try:
                                 header_line = lines[0]
                                 data_line = lines[1] if len(lines) > 1 else ""
@@ -300,27 +333,46 @@ class BenchmarkDataManager(BaseComponent):
 
                                 has_header = sniffer.has_header(content[:1024])
 
                                 if has_header:
-                                    headers = list(csv.reader([header_line], dialect))[0]
-                                    first_data_row = list(csv.reader([data_line], dialect))[0] if data_line else []
+                                    headers = list(csv.reader([header_line], dialect))[
+                                        0
+                                    ]
+                                    first_data_row = (
+                                        list(csv.reader([data_line], dialect))[0]
+                                        if data_line
+                                        else []
+                                    )
                                 else:
-                                    headers = list(csv.reader([header_line], dialect))[0]
+                                    headers = list(csv.reader([header_line], dialect))[
+                                        0
+                                    ]
                                     first_data_row = headers  # first line is data
-                                    headers = [f'col_{i}' for i in range(len(headers))]
+                                    headers = [f"col_{i}" for i in range(len(headers))]
 
                                 # Determine actual number of columns from data
-                                actual_columns = len(first_data_row) if first_data_row else len(headers)
+                                actual_columns = (
+                                    len(first_data_row)
+                                    if first_data_row
+                                    else len(headers)
+                                )
 
                                 # Create table with correct number of columns
                                 create_sql = f"""
-                                    CREATE TABLE IF NOT EXISTS {table_name} (
-                                        {', '.join([f'"{h}" TEXT' for h in headers[:actual_columns]])}
-                                    )
+                                    CREATE TABLE IF NOT EXISTS {table_name} ({
+                                    ", ".join(
+                                        [
+                                            f'"{h}" TEXT'
+                                            for h in headers[:actual_columns]
+                                        ]
+                                    )
+                                })
                                 """
                                 cursor.execute(create_sql)
 
                                 # Prepare insert statement
                                 insert_sql = f"""
-                                    INSERT INTO {table_name} VALUES ({', '.join(['?'] * actual_columns)})
+                                    INSERT INTO {table_name} VALUES ({
+                                    ", ".join(["?"] * actual_columns)
+                                })
                                 """
 
                                 # Process data
@@ -335,9 +387,6 @@ class BenchmarkDataManager(BaseComponent):
 
                                     # Ensure row has correct number of columns
                                     if len(row) != actual_columns:
-                                        logger.warning(
-                                            f"Adjusting row with {len(row)} columns to match {actual_columns} columns"
-                                        )
                                         if len(row) < actual_columns:
                                             row += [None] * (actual_columns - len(row))
                                         else:
@@ -351,16 +400,12 @@ class BenchmarkDataManager(BaseComponent):
 
                                 if batch:
                                     cursor.executemany(insert_sql, batch)
-                                results['successful'] += 1
-                                results['tables_created'].append(table_name)
-                                logger.info(
-                                    f"Imported: {file_info['rel_path']}/{file_info['file_name']} -> {table_name}"
-                                )
+                                results["successful"] += 1
+                                results["tables_created"].append(table_name)
                                 break
 
                             except csv.Error as e:
                                 # Fallback for malformed CSV files
-                                logger.warning(f"CSV parsing error ({encoding}): {str(e)}. Trying simple split")
                                 self._import_with_simple_split(
                                     cursor, table_name, content, results, file_info
                                 )
@@ -374,8 +419,8 @@ class BenchmarkDataManager(BaseComponent):
                 else:
                     # All encodings failed - try binary mode as last resort
                     try:
-                        with open(file_info['full_path'], 'rb') as f:
-                            content = f.read().decode('ascii', errors='ignore')
+                        with open(file_info["full_path"], "rb") as f:
+                            content = f.read().decode("ascii", errors="ignore")
                         if content.strip():
                             self._import_with_simple_split(
                                 cursor, table_name, content, results, file_info
@@ -383,28 +428,32 @@ class BenchmarkDataManager(BaseComponent):
                         else:
                             raise ValueError("File is empty or unreadable")
                     except Exception as e:
-                        results['failed'] += 1
-                        logger.error(f"Failed to process {file_info['file_name']}: {str(e)}")
+                        results["failed"] += 1
+                        logger.error(
+                            f"Failed to process {file_info['file_name']}: {str(e)}"
+                        )
             except Exception as e:
-                results['failed'] += 1
+                results["failed"] += 1
                 logger.error(f"Failed to process {file_info['full_path']}: {str(e)}")
 
         self._db_conn.commit()
         return results
 
-    def _import_with_simple_split(self, cursor, table_name, content, results, file_info):
+    def _import_with_simple_split(
+        self, cursor, table_name, content, results, file_info
+    ):
        """Fallback method for malformed CSV files"""
        # Normalize line endings
-        content = content.replace('\r\n', '\n').replace('\r', '\n')
-        lines = [line for line in content.split('\n') if line.strip()]
+        content = content.replace("\r\n", "\n").replace("\r", "\n")
+        lines = [line for line in content.split("\n") if line.strip()]
 
         if not lines:
             raise ValueError("No data found after cleaning")
 
         # Determine delimiter
         first_line = lines[0]
-        delimiter = ',' if ',' in first_line else '\t' if '\t' in first_line else ';'
+        delimiter = "," if "," in first_line else "\t" if "\t" in first_line else ";"
 
         # Process header
         headers = first_line.split(delimiter)
@@ -413,14 +462,14 @@ class BenchmarkDataManager(BaseComponent):
         # Create table
         create_sql = f"""
             CREATE TABLE IF NOT EXISTS {table_name} (
-                {', '.join([f'col_{i} TEXT' for i in range(actual_columns)])}
+                {", ".join([f"col_{i} TEXT" for i in range(actual_columns)])}
             )
         """
         cursor.execute(create_sql)
 
         # Prepare insert
         insert_sql = f"""
-            INSERT INTO {table_name} VALUES ({', '.join(['?'] * actual_columns)})
+            INSERT INTO {table_name} VALUES ({", ".join(["?"] * actual_columns)})
         """
 
         # Process data
@@ -428,9 +477,6 @@ class BenchmarkDataManager(BaseComponent):
         for line in lines[1:]:  # skip header
             row = line.split(delimiter)
             if len(row) != actual_columns:
-                logger.warning(
-                    f"Adjusting row with {len(row)} columns to match {actual_columns} columns"
-                )
                 if len(row) < actual_columns:
                     row += [None] * (actual_columns - len(row))
                 else:
@@ -444,11 +490,8 @@ class BenchmarkDataManager(BaseComponent):
 
         if batch:
             cursor.executemany(insert_sql, batch)
-        results['successful'] += 1
-        results['tables_created'].append(table_name)
-        logger.info(
-            f"Imported (simple split): {file_info['rel_path']}/{file_info['file_name']} -> {table_name}"
-        )
+        results["successful"] += 1
+        results["tables_created"].append(table_name)
 
     def _cleanup_temp_dir(self):
         """Clean up temporary directory"""
@@ -463,7 +506,9 @@
 
 _SYSTEM_APP: Optional[SystemApp] = None
 
-def initialize_benchmark_data(system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None):
+def initialize_benchmark_data(
+    system_app: SystemApp, config: Optional[BenchmarkDataConfig] = None
+):
     """Initialize benchmark data manager component"""
     global _SYSTEM_APP
     _SYSTEM_APP = system_app
@@ -472,10 +517,12 @@ def initialize_benchmark_data(system_app: SystemApp, config: Optional[BenchmarkD
     return manager
 
 
-def get_benchmark_manager(system_app: Optional[SystemApp] = None) -> BenchmarkDataManager:
+def get_benchmark_manager(
+    system_app: Optional[SystemApp] = None,
+) -> BenchmarkDataManager:
     """Get the benchmark data manager instance"""
     if not _SYSTEM_APP:
         if not system_app:
             system_app = SystemApp()
         initialize_benchmark_data(system_app)
-    return BenchmarkDataManager.get_instance(system_app)
\ No newline at end of file
+    return BenchmarkDataManager.get_instance(system_app)