def sanitize_sql(sql: str, db_type: str = None) -> Tuple[bool, str, dict]:
    """Sanitize a single SQL statement and extract string literals as bind params.

    Args:
        sql: Raw SQL text from the client (untrusted input).
        db_type: Connector dialect name (e.g. ``"duckdb"``); enables extra
            dialect-specific deny rules. May be ``None``.

    Returns:
        Tuple of ``(is_safe, sql_or_reason, params)``. On success the second
        element is the parameterized SQL (each string literal replaced with a
        ``:param_N`` placeholder) and ``params`` maps placeholder names to the
        literal values. On rejection the second element is the reason string
        and ``params`` is ``{}``.
    """
    # Normalize: strip comments and collapse whitespace. DOTALL is required so
    # block comments spanning multiple lines are removed too; otherwise a ";"
    # hidden inside one would false-positive the multi-statement check below.
    # NOTE(review): comment markers inside string literals are stripped as
    # well; avoiding that would require a real SQL tokenizer.
    sql = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL)
    sql = re.sub(r"--.*?$", " ", sql, flags=re.MULTILINE)
    sql = re.sub(r"\s+", " ", sql).strip()

    # Pull string literals out FIRST so a semicolon or keyword inside a quoted
    # value (e.g. WHERE v = 'a;b') neither trips the checks below nor smuggles
    # anything past them. A doubled quote ('') is SQL's escape for a single
    # quote inside a literal, so it is matched and unescaped in the bind value.
    params: dict = {}

    def _bind_literal(match):
        name = f"param_{len(params)}"
        params[name] = match.group(1).replace("''", "'")
        return f":{name}"

    sql = re.sub(r"'((?:[^']|'')*)'", _bind_literal, sql)

    # Block multiple statements (anything meaningful after a semicolon).
    if re.search(r";\s*(?!--|\*/|$)", sql):
        return False, "Multiple SQL statements are not allowed", {}

    # Deny-list of dangerous operations for all databases.
    # NOTE(review): r"(?i)SYSTEM" has no word boundary, so identifiers such as
    # "system_id" are rejected — kept as-is to stay fail-closed.
    dangerous_patterns = [
        r"(?i)INTO\s+(?:OUT|DUMP)FILE",
        r"(?i)LOAD\s+DATA",
        r"(?i)SYSTEM",
        r"(?i)EXEC\s+",
        r"(?i)SHELL\b",
        r"(?i)DROP\s+DATABASE",
        r"(?i)DROP\s+USER",
        r"(?i)GRANT\s+",
        r"(?i)REVOKE\s+",
        r"(?i)ALTER\s+(USER|DATABASE)",
    ]

    # DuckDB exposes file-system and extension features; deny those as well.
    if db_type == "duckdb":
        dangerous_patterns.extend(
            [
                r"(?i)COPY\b",
                r"(?i)EXPORT\b",
                r"(?i)IMPORT\b",
                r"(?i)INSTALL\b",
                r"(?i)READ_\w+\b",
                r"(?i)WRITE_\w+\b",
                r"(?i)\.EXECUTE\(",
                r"(?i)PRAGMA\b",
            ]
        )

    for pattern in dangerous_patterns:
        if re.search(pattern, sql):
            return False, f"Operation not allowed: {pattern}", {}

    # Allow-list of statement kinds. Deliberately not read-only: DML and
    # table-level DDL are permitted by the editor.
    allowed_operations = re.match(
        r"(?i)^\s*(SELECT|CREATE\s+TABLE|INSERT\s+INTO|UPDATE|DELETE\s+FROM|ALTER\s+TABLE)\b",
        sql,
    )
    if not allowed_operations:
        return (
            False,
            "Operation not supported. Only SELECT, CREATE TABLE, INSERT, UPDATE, "
            "DELETE and ALTER TABLE operations are allowed",
            {},
        )

    return True, sql, params
security reasons") - - # Additional check for file path patterns - if re.search(r"['\"].*[/\\].*['\"]", sql): - logger.warning(f"Blocked file path in SQL: {sql}") - return Result.failed(msg="File operations not allowed") + # Sanitize and parameterize the SQL query + is_safe, result, params = sanitize_sql(sql, db_type) + if not is_safe: + logger.warning(f"Blocked dangerous SQL: {sql}") + return Result.failed(msg=f"Operation not allowed: {result}") try: start_time = time.time() * 1000 - # Add timeout protection - colunms, sql_result = conn.query_ex(sql, timeout=30) + # Use the parameterized query and parameters + colunms, sql_result = conn.query_ex(result, params=params, timeout=30) # Convert result type safely sql_result = [ tuple(str(x) if x is not None else None for x in row) for row in sql_result @@ -216,103 +259,57 @@ async def get_editor_chart_info( @router.post("/v1/editor/chart/run", response_model=Result[ChartRunData]) -async def editor_chart_run(run_param: dict = Body()): - logger.info(f"editor_chart_run:{run_param}") +async def chart_run(run_param: dict = Body()): + logger.info(f"chart_run:{run_param}") db_name = run_param["db_name"] sql = run_param["sql"] chart_type = run_param["chart_type"] - # Validate input parameters - if not db_name or not sql or not chart_type: - return Result.failed("Required parameters missing") + # Get database connection + db_conn = CFG.local_db_manager.get_connector(db_name) + db_type = getattr(db_conn, "db_type", "").lower() + + # Sanitize and parameterize the SQL query + is_safe, result, params = sanitize_sql(sql, db_type) + if not is_safe: + logger.warning(f"Blocked dangerous SQL: {sql}") + return Result.failed(msg=f"Operation not allowed: {result}") try: - # Validate database type and prevent dangerous operations - db_conn = CFG.local_db_manager.get_connector(db_name) - db_type = getattr(db_conn, "db_type", "").lower() - - # Block dangerous operations for DuckDB - if db_type == "duckdb": - # Block file operations and system 
commands - dangerous_keywords = [ - # File operations - "copy", - "export", - "import", - "load", - "install", - "read_", - "write_", - "save", - "from_", - "to_", - # System commands - "create_", - "drop_", - ".execute(", - "system", - "shell", - # Additional DuckDB specific operations - "attach", - "detach", - "pragma", - "checkpoint", - "load_extension", - "unload_extension", - # File paths - "/'", - "'/'", - "\\", - "://", - ] - sql_lower = sql.lower().replace(" ", "") # Remove spaces to prevent bypass - if any(keyword in sql_lower for keyword in dangerous_keywords): - logger.warning( - f"Blocked dangerous SQL operation attempt in chart: {sql}" - ) - return Result.failed(msg="Operation not allowed for security reasons") - - # Additional check for file path patterns - if re.search(r"['\"].*[/\\].*['\"]", sql): - logger.warning(f"Blocked file path in chart SQL: {sql}") - return Result.failed(msg="File operations not allowed") - - dashboard_data_loader: DashboardDataLoader = DashboardDataLoader() - start_time = time.time() * 1000 - - # Execute query with timeout - colunms, sql_result = db_conn.query_ex(sql, timeout=30) - - # Safely convert and process results - field_names, chart_values = dashboard_data_loader.get_chart_values_by_data( - colunms, - [ - tuple(str(x) if x is not None else None for x in row) - for row in sql_result - ], - sql, - ) - + # Use the parameterized query and parameters + colunms, sql_result = db_conn.query_ex(result, params=params, timeout=30) + # Convert result type safely + sql_result = [ + tuple(str(x) if x is not None else None for x in row) for row in sql_result + ] # Calculate execution time end_time = time.time() * 1000 sql_run_data: SqlRunData = SqlRunData( result_info="", run_cost=(end_time - start_time) / 1000, colunms=colunms, - values=[list(row) for row in sql_result], + values=sql_result, ) - return Result.succ( - ChartRunData( - sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type + + chart_values = [] + for 
i in range(len(sql_result)): + row = sql_result[i] + chart_values.append( + { + "name": row[0], + "type": "value", + "value": row[1] if len(row) > 1 else "0", + } ) + + chart_data: ChartRunData = ChartRunData( + sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type ) + return Result.succ(chart_data) except Exception as e: - logger.exception("Chart sql run failed!") - sql_result = SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[]) - return Result.succ( - ChartRunData(sql_data=sql_result, chart_values=[], chart_type=chart_type) - ) + logger.error(f"chart_run exception: {str(e)}", exc_info=True) + return Result.failed(msg=str(e)) @router.post("/v1/chart/editor/submit", response_model=Result[bool]) diff --git a/packages/dbgpt-core/src/dbgpt/datasource/rdbms/base.py b/packages/dbgpt-core/src/dbgpt/datasource/rdbms/base.py index 8bea873f5..c30233e12 100644 --- a/packages/dbgpt-core/src/dbgpt/datasource/rdbms/base.py +++ b/packages/dbgpt-core/src/dbgpt/datasource/rdbms/base.py @@ -477,7 +477,11 @@ class RDBMSConnector(BaseConnector): return self._query(sql) def query_ex( - self, query: str, fetch: str = "all", timeout: Optional[float] = None + self, + query: str, + params: Optional[Dict[str, Any]] = None, + fetch: str = "all", + timeout: Optional[float] = None, ) -> Tuple[List[str], Optional[List]]: """Execute a SQL command and return the results with optional timeout. @@ -485,6 +489,7 @@ class RDBMSConnector(BaseConnector): Args: query (str): SQL query to run + params (Optional[dict]): Parameters for the query fetch (str): fetch type, either 'all' or 'one' timeout (Optional[float]): Query timeout in seconds. If None, no timeout is applied. 
@@ -501,13 +506,21 @@ class RDBMSConnector(BaseConnector): return [], None query = self._format_sql(query) - def _execute_query(session, sql_text): - cursor = session.execute(sql_text) + # Initialize params if None + if params is None: + params = {} + + def _execute_query(session, sql_text, query_params): + cursor = session.execute(sql_text, query_params) if cursor.returns_rows: if fetch == "all": result = cursor.fetchall() elif fetch == "one": result = cursor.fetchone() + if result: + result = [result] + else: + result = [] else: raise ValueError("Fetch parameter must be either 'one' or 'all'") field_names = list(cursor.keys()) @@ -526,14 +539,14 @@ class RDBMSConnector(BaseConnector): session.execute( text(f"SET SESSION MAX_EXECUTION_TIME = {mysql_timeout}") ) - return _execute_query(session, sql) + return _execute_query(session, sql, params) elif self.dialect == "postgresql": # PostgreSQL: Set statement_timeout in milliseconds session.execute( text(f"SET statement_timeout = {int(timeout * 1000)}") ) - return _execute_query(session, sql) + return _execute_query(session, sql, params) elif self.dialect == "oceanbase": # OceanBase: Set ob_query_timeout in microseconds @@ -541,17 +554,19 @@ class RDBMSConnector(BaseConnector): session.execute( text(f"SET SESSION ob_query_timeout = {ob_timeout}") ) - return _execute_query(session, sql) + return _execute_query(session, sql, params) elif self.dialect == "mssql": # MSSQL: Use execution_options if supported by driver sql_with_timeout = sql.execution_options(timeout=int(timeout)) - return _execute_query(session, sql_with_timeout) + return _execute_query(session, sql_with_timeout, params) elif self.dialect == "duckdb": # DuckDB: Use ThreadPoolExecutor for timeout with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(_execute_query, session, sql) + future = executor.submit( + _execute_query, session, sql, params + ) try: return future.result(timeout=timeout) except FutureTimeoutError: @@ -564,10 
+579,10 @@ class RDBMSConnector(BaseConnector): f"Timeout not supported for dialect: {self.dialect}, " "proceeding without timeout" ) - return _execute_query(session, sql) + return _execute_query(session, sql, params) # No timeout specified, execute normally - return _execute_query(session, sql) + return _execute_query(session, sql, params) except SQLAlchemyError as e: if "timeout" in str(e).lower() or "timed out" in str(e).lower():