diff --git a/packages/dbgpt-app/src/dbgpt_app/component_configs.py b/packages/dbgpt-app/src/dbgpt_app/component_configs.py index 623b3e996..de2c75e1b 100644 --- a/packages/dbgpt-app/src/dbgpt_app/component_configs.py +++ b/packages/dbgpt-app/src/dbgpt_app/component_configs.py @@ -59,6 +59,7 @@ def initialize_components( _initialize_code_server(system_app) # Initialize prompt templates - MUST be after serve apps registration _initialize_prompt_templates() + _initialize_benchmark_data(system_app) def _initialize_model_cache(system_app: SystemApp, web_config: ServiceWebParameters): @@ -206,3 +207,11 @@ def _initialize_prompt_templates(): logger.error(f"Failed to initialize prompt templates: {e}") # Don't raise exception to avoid breaking the application startup # The templates will be loaded lazily when needed + + +def _initialize_benchmark_data(system_app: SystemApp): + from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import ( + initialize_benchmark_data, + ) + + initialize_benchmark_data(system_app) diff --git a/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py b/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py index cc9413a33..02223bfd1 100644 --- a/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py +++ b/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py @@ -1,4 +1,3 @@ -import asyncio import logging import os import sys @@ -36,9 +35,6 @@ from dbgpt_app.base import ( from dbgpt_app.component_configs import initialize_components from dbgpt_app.config import ApplicationConfig, ServiceWebParameters, SystemParameters from dbgpt_serve.core import add_exception_handler -from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import ( - get_benchmark_manager, -) logger = logging.getLogger(__name__) ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -148,13 +144,6 @@ def initialize_app(param: ApplicationConfig, args: List[str] = None): # After init, when the database is ready system_app.after_init() - # Async fetch benchmark dataset from Falcon - loop = asyncio.get_event_loop() - if loop.is_running(): - loop.create_task(load_benchmark_data()) - else: - loop.run_until_complete(load_benchmark_data()) - binding_port = web_config.port binding_host = web_config.host if not web_config.light: @@ -330,40 +319,6 @@ def parse_args(): return parser.parse_args() -async def load_benchmark_data(): - """Load benchmark data from GitHub repository into SQLite database""" - logging.basicConfig(level=logging.INFO) - logger.info("Starting benchmark data loading process...") - - try: - manager = get_benchmark_manager(system_app) - - async with manager: - logger.info("Fetching data from GitHub repository...") - result = await manager.load_from_github( - repo_url="https://github.com/inclusionAI/Falcon", data_dir="data/source" - ) - - # Log detailed results - logger.info("\nBenchmark Data Loading Summary:") - logger.info(f"Total CSV files processed: {result['total_files']}") - logger.info(f"Successfully imported: {result['successful']}") - logger.info(f"Failed imports: {result['failed']}") - - if result["failed"] > 0: - logger.warning(f"Encountered {result['failed']} failures during import") - - # Verify the loaded data - table_info = await manager.get_table_info() - logger.info(f"Loaded {len(table_info)} tables into database") - - return {"import_result": result, "table_info": table_info} - - except Exception as e: - logger.error("Failed to load benchmark data", exc_info=True) - raise RuntimeError(f"Benchmark data loading failed: {str(e)}") from e - - if 
__name__ == "__main__": # Parse command line arguments _args = parse_args() diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py index ec15f9fe5..b0174c4e8 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/file_parse_service.py @@ -5,8 +5,8 @@ import os from typing import List import pandas as pd -from openpyxl.reader.excel import load_workbook +from openpyxl.reader.excel import load_workbook from dbgpt.util.benchmarks.ExcelUtils import ExcelUtils from .models import ( @@ -163,6 +163,7 @@ class FileParseService: return outputs + class ExcelFileParseService(FileParseService): def parse_input_sets(self, location: str) -> BenchmarkDataSets: """ diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py index da7f4a585..c9398bde1 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/benchmark/models.py @@ -13,7 +13,6 @@ class BenchmarkModeTypeEnum(str, Enum): BUILD = "BUILD" EXECUTE = "EXECUTE" - @dataclass class DataCompareStrategyConfig: strategy: str # "EXACT_MATCH" | "CONTAIN_MATCH" @@ -131,6 +130,7 @@ class RoundAnswerConfirmModel: compareResult: Optional[DataCompareResultEnum] = None + class FileParseTypeEnum(Enum): """文件解析类型枚举""" @@ -158,6 +158,11 @@ class ContentTypeEnum(Enum): @dataclass class BenchmarkExecuteConfig: + benchmarkModeType: BenchmarkModeTypeEnum + compareResultEnable: bool + standardFilePath: Optional[str] = None + compareConfig: Optional[Dict[str, str]] = None + """ Benchmark Execute Config """ diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py index a7c6bd867..141fbe3bc 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/fetchdata/benchmark_data_manager.py @@ -5,17 +5,18 @@ import json import logging import os import shutil -import sqlite3 import tempfile import time import zipfile from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import aiohttp +from sqlalchemy import text from dbgpt._private.pydantic import BaseModel, ConfigDict from dbgpt.component import BaseComponent, ComponentType, SystemApp +from dbgpt_ext.datasource.rdbms.conn_sqlite import SQLiteConnector logger = logging.getLogger(__name__) @@ -26,9 +27,11 @@ class BenchmarkDataConfig(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) cache_dir: str = "cache" - db_path: str = "pilot/benchmark_meta_data/benchmark_data.db" + db_path: str = "pilot/benchmark_meta_data/ant_icube_dev.db" table_mapping_file: str = "pilot/benchmark_meta_data/table_mapping.json" cache_expiry_days: int = 1 + repo_url: str = "https://github.com/inclusionAI/Falcon" + data_dir: str = "data/source" class BenchmarkDataManager(BaseComponent): @@ -41,56 +44,174 @@ class BenchmarkDataManager(BaseComponent): ): super().__init__(system_app) self._config = config or BenchmarkDataConfig() - self._http_session = None - self._db_conn = None + self._http_session: 
Optional[aiohttp.ClientSession] = None + self._connector: Optional[SQLiteConnector] = None self._table_mappings = self._load_mappings() self._lock = asyncio.Lock() - self.temp_dir = None + self.temp_dir: Optional[str] = None # Ensure directories exist os.makedirs(self._config.cache_dir, exist_ok=True) + db_dir = os.path.dirname(self._config.db_path) + if db_dir: + os.makedirs(db_dir, exist_ok=True) + self._startup_loaded: bool = False def init_app(self, system_app: SystemApp): """Initialize the AgentManager.""" self.system_app = system_app + async def async_after_start(self): + logger.info("BenchmarkDataManager: async_after_start.") + + try: + if not self._config.repo_url: + logger.info("BenchmarkDataManager: repo_url not set, skip auto load.") + return + + if self._startup_loaded: + logger.info("BenchmarkDataManager: already loaded on startup, skip.") + return + + logger.info( + f"BenchmarkDataManager: auto loading repo {self._config.repo_url} " + f"dir={self._config.data_dir}" + ) + await get_benchmark_manager(self.system_app).load_from_github( + repo_url=self._config.repo_url, data_dir=self._config.data_dir + ) + self._startup_loaded = True + logger.info("BenchmarkDataManager: auto load finished.") + except Exception as e: + logger.error(f"BenchmarkDataManager: auto load failed: {e}") + + async def async_before_stop(self): + try: + logger.info("BenchmarkDataManager: closing resources before stop...") + await self.close() + logger.info("BenchmarkDataManager: close done.") + except Exception as e: + logger.warning(f"BenchmarkDataManager: close failed: {e}") + async def __aenter__(self): self._http_session = aiohttp.ClientSession() + await self.init_connector() return self async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() + async def init_connector(self): + """Initialize SQLiteConnector""" + async with self._lock: + if not self._connector: + self._connector = SQLiteConnector.from_file_path(self._config.db_path) + + async def close_connector(self): + """Close SQLiteConnector""" + async with self._lock: + if self._connector: + try: + self._connector.close() + except Exception as e: + logger.warning(f"Close connector failed: {e}") + self._connector = None + async def close(self): """Clean up resources""" if self._http_session: await self._http_session.close() self._http_session = None - if self._db_conn: - self._db_conn.close() - self._db_conn = None + await self.close_connector() self._cleanup_temp_dir() - async def get_connection(self) -> sqlite3.Connection: - """Get database connection (thread-safe)""" - async with self._lock: - if not self._db_conn: - self._db_conn = sqlite3.connect(self._config.db_path) - return self._db_conn + async def _run_in_thread(self, func, *args, **kwargs): + """Run blocking function in thread to avoid blocking event loop""" + return await asyncio.to_thread(func, *args, **kwargs) + + def _sanitize_column_name(self, name: str) -> str: + if name is None: + return "" + name = str(name).strip().strip('"').strip("'") + invalid_chars = [ + "-", + " ", + ".", + ",", + ";", + ":", + "!", + "?", + "'", + '"', + "(", + ")", + "[", + "]", + "{", + "}", + "\t", + "\r", + "\n", + "\x00", + ] + for ch in invalid_chars: + name = name.replace(ch, "_") + while "__" in name: + name = name.replace("__", "_") + if name and not (name[0].isalpha() or name[0] == "_"): + name = "_" + name + return name.lower() + + def _sanitize_and_dedup_headers(self, headers: List[str]) -> List[str]: + sanitized: List[str] = [] + used: set = set() + for idx, h in 
enumerate(headers): + name = self._sanitize_column_name(h) + if not name: + name = f"col_{idx}" + base = name + k = 2 + while name in used or not name: + name = f"{base}_{k}" + k += 1 + used.add(name) + sanitized.append(name) + return sanitized + + # ========================================================== + + # 通用查询(阻塞实现,在线程池中调用) + def _query_blocking(self, sql: str, params: Optional[Dict[str, Any]] = None): + assert self._connector is not None, "Connector not initialized" + with self._connector.session_scope() as session: + cursor = session.execute(text(sql), params or {}) + rows = cursor.fetchall() + # SQLAlchemy 2.0: cursor.keys() 提供列名 + cols = list(cursor.keys()) + return cols, rows + + # 通用写入(阻塞实现,在线程池中调用) + def _execute_blocking(self, sql: str, params: Optional[Dict[str, Any]] = None): + assert self._connector is not None, "Connector not initialized" + with self._connector.session_scope() as session: + result = session.execute(text(sql), params or {}) + session.commit() + return result.rowcount async def query(self, query: str, params: tuple = ()) -> List[Dict]: """Execute query and return results as dict list""" - conn = await self.get_connection() - cursor = conn.cursor() - cursor.execute(query, params) - columns = [col[0] for col in cursor.description] - return [dict(zip(columns, row)) for row in cursor.fetchall()] + await self.init_connector() + cols, rows = await self._run_in_thread(self._query_blocking, query, params) + return [dict(zip(cols, row)) for row in rows] async def load_from_github( self, repo_url: str, data_dir: str = "data/source" ) -> Dict: """Main method to load data from GitHub repository""" try: + await self.init_connector() + # 1. Download or use cached repository repo_dir = await self._download_repo_contents(repo_url) @@ -106,31 +227,38 @@ class BenchmarkDataManager(BaseComponent): except Exception as e: logger.error(f"Import failed: {str(e)}") - raise + raise RuntimeError(f"Benchmark data loading failed: {e}") from e finally: self._cleanup_temp_dir() async def get_table_info(self) -> Dict: """Get metadata about all tables""" - conn = await self.get_connection() - cursor = conn.cursor() + await self.init_connector() + assert self._connector is not None - cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") - tables = cursor.fetchall() + def _work(): + with self._connector.session_scope() as session: + tables = session.execute( + text("SELECT name FROM sqlite_master WHERE type='table'") + ) + tables = [row[0] for row in tables.fetchall()] + result: Dict[str, Any] = {} + for table_name in tables: + row_count = session.execute( + text(f'SELECT COUNT(*) FROM "{table_name}"') + ).fetchone()[0] + columns = session.execute( + text(f'PRAGMA table_info("{table_name}")') + ).fetchall() + result[table_name] = { + "row_count": row_count, + "columns": [ + {"name": col[1], "type": col[2]} for col in columns + ], + } + return result - result = {} - for table in tables: - table_name = table[0] - cursor.execute(f"SELECT COUNT(*) FROM {table_name}") - row_count = cursor.fetchone()[0] - cursor.execute(f"PRAGMA table_info({table_name})") - columns = cursor.fetchall() - - result[table_name] = { - "row_count": row_count, - "columns": [{"name": col[1], "type": col[2]} for col in columns], - } - return result + return await self._run_in_thread(_work) def clear_cache(self): """Clear cached repository files""" @@ -214,6 +342,8 @@ class BenchmarkDataManager(BaseComponent): logger.info(f"Downloading from GitHub repo: {zip_url}") try: + if self._http_session is None: + 
self._http_session = aiohttp.ClientSession() async with self._http_session.get(zip_url) as response: response.raise_for_status() zip_path = os.path.join(self.temp_dir, "repo.zip") @@ -233,7 +363,7 @@ class BenchmarkDataManager(BaseComponent): except Exception as e: self._cleanup_temp_dir() - raise RuntimeError(f"Failed to download repository: {str(e)}") + raise RuntimeError(f"Failed to download repository: {str(e)}") from e def _get_cache_path(self, repo_url: str) -> str: """Get path to cached zip file""" @@ -288,8 +418,8 @@ class BenchmarkDataManager(BaseComponent): async def _import_to_database(self, csv_files: List[Dict]) -> Dict: """Import CSV data to SQLite""" - conn = await self.get_connection() - cursor = conn.cursor() + await self.init_connector() + assert self._connector is not None results = { "total_files": len(csv_files), "successful": 0, @@ -297,13 +427,13 @@ class BenchmarkDataManager(BaseComponent): "tables_created": [], } - for file_info in csv_files: + def _process_one_file(file_info: Dict) -> Tuple[bool, Optional[str]]: + table_name = "" try: path_parts = [p for p in file_info["rel_path"].split(os.sep) if p] table_name = "_".join(path_parts + [Path(file_info["file_name"]).stem]) table_name = self._sanitize_table_name(table_name) - # Try multiple encodings encodings = ["utf-8-sig", "utf-8", "latin-1", "iso-8859-1", "cp1252"] for encoding in encodings: @@ -311,187 +441,220 @@ class BenchmarkDataManager(BaseComponent): with open(file_info["full_path"], "r", encoding=encoding) as f: content = f.read() - # Handle empty files - if not content.strip(): - raise ValueError("File is empty") + if not content.strip(): + raise ValueError("File is empty") - # Replace problematic line breaks if needed - content = content.replace("\r\n", "\n").replace("\r", "\n") + content = content.replace("\r\n", "\n").replace("\r", "\n") + lines = [line for line in content.split("\n") if line.strip()] + if not lines: + raise ValueError("No data after normalization") - # Split into lines - lines = [ - line for line in content.split("\n") if line.strip() - ] + header_line = lines[0] + data_line = lines[1] if len(lines) > 1 else "" + + try: + sample_for_sniff = "\n".join(lines[:10]) + sniffer = csv.Sniffer() + try: + dialect = sniffer.sniff(sample_for_sniff) + except Exception: + + class _DefaultDialect(csv.Dialect): + delimiter = "," + quotechar = '"' + doublequote = True + skipinitialspace = False + lineterminator = "\n" + quoting = csv.QUOTE_MINIMAL + + dialect = _DefaultDialect() try: - header_line = lines[0] - data_line = lines[1] if len(lines) > 1 else "" + has_header = sniffer.has_header("\n".join(lines[:50])) + except Exception: + has_header = True - # Detect delimiter (comma, semicolon, tab) - sniffer = csv.Sniffer() - dialect = sniffer.sniff(header_line) - has_header = sniffer.has_header(content[:1024]) + header_row = ( + list(csv.reader([header_line], dialect))[0] + if header_line + else [] + ) + first_data_row = ( + list(csv.reader([data_line], dialect))[0] + if data_line + else [] + ) - if has_header: - headers = list(csv.reader([header_line], dialect))[ - 0 - ] - first_data_row = ( - list(csv.reader([data_line], dialect))[0] - if data_line - else [] - ) - else: - headers = list(csv.reader([header_line], dialect))[ - 0 - ] - first_data_row = headers # first line is data - headers = [f"col_{i}" for i in range(len(headers))] + if not has_header: + num_cols_guess = len(header_row) + headers = [f"col_{i}" for i in range(num_cols_guess)] + first_data_row = header_row + else: + headers = 
header_row - # Determine actual number of columns from data - actual_columns = ( - len(first_data_row) - if first_data_row - else len(headers) + num_cols = ( + len(first_data_row) if first_data_row else len(headers) + ) + + # no header + if not headers or all( + (not str(h).strip()) for h in headers + ): + headers = [f"col_{i}" for i in range(num_cols or 1)] + + headers = self._sanitize_and_dedup_headers(headers) + + if num_cols <= 0: + num_cols = len(headers) + headers = headers[:num_cols] + if not headers or any( + h is None or h == "" for h in headers + ): + raise csv.Error("Invalid headers after sanitization") + + create_sql = f''' + CREATE TABLE IF NOT EXISTS "{table_name}" ( + {", ".join([f'"{h}" TEXT' for h in headers])} ) + ''' + insert_sql = f''' + INSERT INTO "{table_name}" ({ + ", ".join([f'"{h}"' for h in headers]) + }) + VALUES ({ + ", ".join([":" + f"p{i}" for i in range(len(headers))]) + }) + ''' - # Create table with correct number of columns - create_sql = f""" - CREATE TABLE IF NOT EXISTS {table_name} ({ - ", ".join( - [ - f'"{h}" TEXT' - for h in headers[:actual_columns] - ] - ) - }) - """ - cursor.execute(create_sql) + with self._connector.session_scope() as session: + logger.debug( + f"Table: {table_name}, headers(final): {headers}" + ) + session.execute(text(create_sql)) - # Prepare insert statement - insert_sql = f""" - INSERT INTO {table_name} VALUES ({ - ", ".join(["?"] * actual_columns) - }) - """ - - # Process data - batch = [] reader = csv.reader(lines, dialect) if has_header: - next(reader) # skip header + next(reader, None) + batch_params: List[Dict[str, Any]] = [] for row in reader: - if not row: # skip empty rows + if not row: continue - - # Ensure row has correct number of columns - if len(row) != actual_columns: - if len(row) < actual_columns: - row += [None] * (actual_columns - len(row)) + if len(row) != len(headers): + if len(row) < len(headers): + row += [None] * (len(headers) - len(row)) else: - row = row[:actual_columns] + row = row[: len(headers)] + params = { + f"p{i}": (row[i] if i < len(row) else None) + for i in range(len(headers)) + } + batch_params.append(params) + if len(batch_params) >= 1000: + session.execute(text(insert_sql), batch_params) + batch_params = [] + if batch_params: + session.execute(text(insert_sql), batch_params) + session.commit() - batch.append(row) - if len(batch) >= 1000: - cursor.executemany(insert_sql, batch) - batch = [] + return True, table_name - if batch: - cursor.executemany(insert_sql, batch) - - results["successful"] += 1 - results["tables_created"].append(table_name) - break - - except csv.Error as e: - # Fallback for malformed CSV files - self._import_with_simple_split( - cursor, table_name, content, results, file_info - ) - break + except csv.Error: + self._import_with_simple_split_blocking(table_name, content) + return True, table_name except UnicodeDecodeError: continue except Exception as e: logger.warning(f"Error with encoding {encoding}: {str(e)}") continue - else: - # All encodings failed - try binary mode as last resort - try: - with open(file_info["full_path"], "rb") as f: - content = f.read().decode("ascii", errors="ignore") - if content.strip(): - self._import_with_simple_split( - cursor, table_name, content, results, file_info - ) - else: - raise ValueError("File is empty or unreadable") - except Exception as e: - results["failed"] += 1 - logger.error( - f"Failed to process {file_info['file_name']}: {str(e)}" - ) + + try: + with open(file_info["full_path"], "rb") as f: + content = 
f.read().decode("ascii", errors="ignore") + if content.strip(): + self._import_with_simple_split_blocking(table_name, content) + return True, table_name + else: + raise ValueError("File is empty or unreadable") + except Exception as e: + return ( + False, + f"Failed to process {file_info['file_name']}: {str(e)}", + ) except Exception as e: - results["failed"] += 1 - logger.error(f"Failed to process {file_info['full_path']}: {str(e)}") + return ( + False, + f"Failed to process {file_info.get('full_path', '')}: {str(e)}", + ) + + for file_info in csv_files: + ok, info = await self._run_in_thread(_process_one_file, file_info) + if ok: + results["successful"] += 1 + if info: + results["tables_created"].append(info) + else: + results["failed"] += 1 + logger.error(info) - self._db_conn.commit() return results - def _import_with_simple_split( - self, cursor, table_name, content, results, file_info - ): - """Fallback method for malformed CSV files""" - # Normalize line endings + def _import_with_simple_split_blocking(self, table_name: str, content: str): + """Fallback method for malformed CSV files (blocking, 使用 SQLAlchemy 执行)""" + assert self._connector is not None content = content.replace("\r\n", "\n").replace("\r", "\n") lines = [line for line in content.split("\n") if line.strip()] - if not lines: raise ValueError("No data found after cleaning") - # Determine delimiter first_line = lines[0] delimiter = "," if "," in first_line else "\t" if "\t" in first_line else ";" - # Process header - headers = first_line.split(delimiter) + raw_headers = first_line.split(delimiter) + headers = self._sanitize_and_dedup_headers(raw_headers) actual_columns = len(headers) - # Create table create_sql = f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - {", ".join([f"col_{i} TEXT" for i in range(actual_columns)])} - ) + CREATE TABLE IF NOT EXISTS "{table_name}" ( + {", ".join([f'"{h}" TEXT' for h in headers])} + ) """ - cursor.execute(create_sql) - # Prepare insert insert_sql = f""" - INSERT INTO {table_name} VALUES ({", ".join(["?"] * actual_columns)}) + INSERT INTO "{table_name}" ({", ".join([f'"{h}"' for h in headers])}) + VALUES ({", ".join([":" + f"p{i}" for i in range(actual_columns)])}) """ - # Process data - batch = [] - for line in lines[1:]: # skip header - row = line.split(delimiter) - if len(row) != actual_columns: - if len(row) < actual_columns: - row += [None] * (actual_columns - len(row)) - else: - row = row[:actual_columns] - batch.append(row) + with self._connector.session_scope() as session: + session.execute(text(create_sql)) + batch: List[Dict[str, Any]] = [] + for line in lines[1:]: + row = line.split(delimiter) + if len(row) != actual_columns: + if len(row) < actual_columns: + row += [None] * (actual_columns - len(row)) + else: + row = row[:actual_columns] + params = {f"p{i}": row[i] for i in range(actual_columns)} + batch.append(params) + if len(batch) >= 1000: + session.execute(text(insert_sql), batch) + batch = [] + if batch: + session.execute(text(insert_sql), batch) + session.commit() - if len(batch) >= 1000: - cursor.executemany(insert_sql, batch) - batch = [] + async def get_table_info_simple(self) -> List[str]: + """Return simplified table info: table(column1,column2,...)""" + await self.init_connector() + assert self._connector is not None - if batch: - cursor.executemany(insert_sql, batch) + def _work(): + return list(self._connector.table_simple_info()) - results["successful"] += 1 - results["tables_created"].append(table_name) + return await self._run_in_thread(_work) def 
_cleanup_temp_dir(self): """Clean up temporary directory"""
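Usage sketch (not part of the patch): a minimal example of driving the new BenchmarkDataManager by hand, assuming a running DB-GPT SystemApp and the dbgpt_serve evaluate package; in the server itself this now happens automatically via _initialize_benchmark_data() and async_after_start(). The table name in the final query is a placeholder.

from dbgpt.component import SystemApp
from dbgpt_serve.evaluate.service.fetchdata.benchmark_data_manager import (
    get_benchmark_manager,
)


async def load_and_inspect(system_app: SystemApp):
    manager = get_benchmark_manager(system_app)
    # __aenter__ opens the aiohttp session and the SQLiteConnector
    async with manager:
        result = await manager.load_from_github(
            repo_url="https://github.com/inclusionAI/Falcon",
            data_dir="data/source",
        )
        print(f"imported={result['successful']} failed={result['failed']}")
        # Simplified schema listing: table(col1,col2,...)
        print(await manager.get_table_info_simple())
        # Plain SQL against the imported SQLite tables
        # ("some_imported_table" is a placeholder name)
        rows = await manager.query('SELECT * FROM "some_imported_table" LIMIT 5')
        print(rows)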
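The import hunk above replaces the old single-line delimiter sniff with a multi-line sample plus a hard-coded comma fallback. A standalone rendering of that strategy (illustrative only; it uses csv.get_dialect("excel") instead of the patch's inline _DefaultDialect class):

import csv


def detect_dialect(lines):
    """Sniff the delimiter from up to 10 lines; fall back to comma-separated."""
    sample = "\n".join(lines[:10])
    try:
        return csv.Sniffer().sniff(sample)
    except csv.Error:
        return csv.get_dialect("excel")  # defaults to ',' with minimal quoting


lines = ["id;name;amount", "1;foo;9.5"]
print(list(csv.reader(lines, detect_dialect(lines))))
# [['id', 'name', 'amount'], ['1', 'foo', '9.5']]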
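The new header handling sanitizes and deduplicates column names before building the CREATE TABLE / INSERT statements. A simplified, regex-based stand-in (not the exact helper added in this patch) that shows the intended behaviour:

import re
from typing import List


def sanitize_and_dedup_headers(headers: List[str]) -> List[str]:
    """Make CSV headers safe, unique SQLite column names."""
    out, used = [], set()
    for idx, raw in enumerate(headers):
        name = re.sub(r"[^0-9A-Za-z_]+", "_", str(raw or "").strip())
        name = re.sub(r"_+", "_", name).lower()
        if name and not (name[0].isalpha() or name[0] == "_"):
            name = "_" + name
        if not name:
            name = f"col_{idx}"  # empty header -> positional fallback
        base, k = name, 2
        while name in used:
            name = f"{base}_{k}"  # duplicate header -> numeric suffix
            k += 1
        used.add(name)
        out.append(name)
    return out


print(sanitize_and_dedup_headers(["Order ID", "order-id", "", "Amount ($)"]))
# ['order_id', 'order_id_2', 'col_2', 'amount_']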