[BUG] SQL Injection through CVE Bypass in DB-GPT 0.7.0 (CVE-2024-10835 & CVE-2024-10901) (#2650)

Author: Gecko Security
Date: 2025-04-28 07:54:41 +02:00 (committed via GitHub)
Co-authored-by: nkoorty <amalyshau2002@gmail.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Commit: bcb43266cf (parent: b16c6793ec)
2 changed files with 149 additions and 137 deletions


@@ -2,7 +2,7 @@ import json
 import logging
 import re
 import time
-from typing import Dict, List
+from typing import Dict, List, Tuple
 from fastapi import APIRouter, Body, Depends
@@ -95,66 +95,109 @@ async def get_editor_sql(
     return Result.failed(msg="not have sql!")
+
+
+def sanitize_sql(sql: str, db_type: str = None) -> Tuple[bool, str, dict]:
+    """Simple SQL sanitizer to prevent injection.
+
+    Returns:
+        Tuple of (is_safe, reason, params)
+    """
+    # Normalize SQL (remove comments and excess whitespace)
+    sql = re.sub(r"/\*.*?\*/", " ", sql)
+    sql = re.sub(r"--.*?$", " ", sql, flags=re.MULTILINE)
+    sql = re.sub(r"\s+", " ", sql).strip()
+
+    # Block multiple statements
+    if re.search(r";\s*(?!--|\*/|$)", sql):
+        return False, "Multiple SQL statements are not allowed", {}
+
+    # Block dangerous operations for all databases
+    dangerous_patterns = [
+        r"(?i)INTO\s+(?:OUT|DUMP)FILE",
+        r"(?i)LOAD\s+DATA",
+        r"(?i)SYSTEM",
+        r"(?i)EXEC\s+",
+        r"(?i)SHELL\b",
+        r"(?i)DROP\s+DATABASE",
+        r"(?i)DROP\s+USER",
+        r"(?i)GRANT\s+",
+        r"(?i)REVOKE\s+",
+        r"(?i)ALTER\s+(USER|DATABASE)",
+    ]
+
+    # Add DuckDB specific patterns
+    if db_type == "duckdb":
+        dangerous_patterns.extend(
+            [
+                r"(?i)COPY\b",
+                r"(?i)EXPORT\b",
+                r"(?i)IMPORT\b",
+                r"(?i)INSTALL\b",
+                r"(?i)READ_\w+\b",
+                r"(?i)WRITE_\w+\b",
+                r"(?i)\.EXECUTE\(",
+                r"(?i)PRAGMA\b",
+            ]
+        )
+
+    for pattern in dangerous_patterns:
+        if re.search(pattern, sql):
+            return False, f"Operation not allowed: {pattern}", {}
+
+    # Allow SELECT, CREATE TABLE, INSERT, UPDATE, and DELETE operations
+    # We're no longer restricting to read-only operations
+    allowed_operations = re.match(
+        r"(?i)^\s*(SELECT|CREATE\s+TABLE|INSERT\s+INTO|UPDATE|DELETE\s+FROM|ALTER\s+TABLE)\b",
+        sql,
+    )
+    if not allowed_operations:
+        return (
+            False,
+            "Operation not supported. Only SELECT, CREATE TABLE, INSERT, UPDATE, "
+            "DELETE and ALTER TABLE operations are allowed",
+            {},
+        )
+
+    # Extract parameters (simplified)
+    params = {}
+    param_count = 0
+
+    # Extract string literals
+    def replace_string(match):
+        nonlocal param_count
+        param_name = f"param_{param_count}"
+        params[param_name] = match.group(1)
+        param_count += 1
+        return f":{param_name}"
+
+    # Replace string literals with parameters
+    parameterized_sql = re.sub(r"'([^']*)'", replace_string, sql)
+
+    return True, parameterized_sql, params
+
+
 @router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
 async def editor_sql_run(run_param: dict = Body()):
     logger.info(f"editor_sql_run:{run_param}")
     db_name = run_param["db_name"]
     sql = run_param["sql"]
     if not db_name and not sql:
         return Result.failed(msg="SQL run param error")
-    # Validate database type and prevent dangerous operations
+    # Get database connection
     conn = CFG.local_db_manager.get_connector(db_name)
     db_type = getattr(conn, "db_type", "").lower()
-    # Block dangerous operations for DuckDB
-    if db_type == "duckdb":
-        # Block file operations and system commands
-        dangerous_keywords = [
-            # File operations
-            "copy",
-            "export",
-            "import",
-            "load",
-            "install",
-            "read_",
-            "write_",
-            "save",
-            "from_",
-            "to_",
-            # System commands
-            "create_",
-            "drop_",
-            ".execute(",
-            "system",
-            "shell",
-            # Additional DuckDB specific operations
-            "attach",
-            "detach",
-            "pragma",
-            "checkpoint",
-            "load_extension",
-            "unload_extension",
-            # File paths
-            "/'",
-            "'/'",
-            "\\",
-            "://",
-        ]
-        sql_lower = sql.lower().replace(" ", "")  # Remove spaces to prevent bypass
-        if any(keyword in sql_lower for keyword in dangerous_keywords):
-            logger.warning(f"Blocked dangerous SQL operation attempt: {sql}")
-            return Result.failed(msg="Operation not allowed for security reasons")
-        # Additional check for file path patterns
-        if re.search(r"['\"].*[/\\].*['\"]", sql):
-            logger.warning(f"Blocked file path in SQL: {sql}")
-            return Result.failed(msg="File operations not allowed")
+    # Sanitize and parameterize the SQL query
+    is_safe, result, params = sanitize_sql(sql, db_type)
+    if not is_safe:
+        logger.warning(f"Blocked dangerous SQL: {sql}")
+        return Result.failed(msg=f"Operation not allowed: {result}")
     try:
         start_time = time.time() * 1000
-        # Add timeout protection
-        colunms, sql_result = conn.query_ex(sql, timeout=30)
+        # Use the parameterized query and parameters
+        colunms, sql_result = conn.query_ex(result, params=params, timeout=30)
         # Convert result type safely
         sql_result = [
             tuple(str(x) if x is not None else None for x in row) for row in sql_result
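
For reference, a minimal sketch of how the new sanitize_sql helper behaves on two representative inputs. The sample queries and the expected outputs in the comments are illustrative assumptions, not taken from the project's tests, and the snippet assumes sanitize_sql from the hunk above is in scope (its module path is not shown in this diff).

# Illustrative usage of sanitize_sql as defined above (assumed to be importable or pasted in scope).
is_safe, parameterized, params = sanitize_sql("SELECT * FROM users WHERE name = 'alice'")
print(is_safe)        # True
print(parameterized)  # SELECT * FROM users WHERE name = :param_0
print(params)         # {'param_0': 'alice'}

# With db_type="duckdb", file-system operations are rejected before parameterization.
is_safe, reason, params = sanitize_sql("COPY users TO '/tmp/users.csv'", db_type="duckdb")
print(is_safe)  # False
print(reason)   # Operation not allowed: (?i)COPY\b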
@@ -216,103 +259,57 @@ async def get_editor_chart_info(
 @router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
-async def editor_chart_run(run_param: dict = Body()):
-    logger.info(f"editor_chart_run:{run_param}")
+async def chart_run(run_param: dict = Body()):
+    logger.info(f"chart_run:{run_param}")
     db_name = run_param["db_name"]
     sql = run_param["sql"]
     chart_type = run_param["chart_type"]
-    # Validate input parameters
-    if not db_name or not sql or not chart_type:
-        return Result.failed("Required parameters missing")
+    # Get database connection
+    db_conn = CFG.local_db_manager.get_connector(db_name)
+    db_type = getattr(db_conn, "db_type", "").lower()
+    # Sanitize and parameterize the SQL query
+    is_safe, result, params = sanitize_sql(sql, db_type)
+    if not is_safe:
+        logger.warning(f"Blocked dangerous SQL: {sql}")
+        return Result.failed(msg=f"Operation not allowed: {result}")
     try:
-        # Validate database type and prevent dangerous operations
-        db_conn = CFG.local_db_manager.get_connector(db_name)
-        db_type = getattr(db_conn, "db_type", "").lower()
-        # Block dangerous operations for DuckDB
-        if db_type == "duckdb":
-            # Block file operations and system commands
-            dangerous_keywords = [
-                # File operations
-                "copy",
-                "export",
-                "import",
-                "load",
-                "install",
-                "read_",
-                "write_",
-                "save",
-                "from_",
-                "to_",
-                # System commands
-                "create_",
-                "drop_",
-                ".execute(",
-                "system",
-                "shell",
-                # Additional DuckDB specific operations
-                "attach",
-                "detach",
-                "pragma",
-                "checkpoint",
-                "load_extension",
-                "unload_extension",
-                # File paths
-                "/'",
-                "'/'",
-                "\\",
-                "://",
-            ]
-            sql_lower = sql.lower().replace(" ", "")  # Remove spaces to prevent bypass
-            if any(keyword in sql_lower for keyword in dangerous_keywords):
-                logger.warning(
-                    f"Blocked dangerous SQL operation attempt in chart: {sql}"
-                )
-                return Result.failed(msg="Operation not allowed for security reasons")
-            # Additional check for file path patterns
-            if re.search(r"['\"].*[/\\].*['\"]", sql):
-                logger.warning(f"Blocked file path in chart SQL: {sql}")
-                return Result.failed(msg="File operations not allowed")
-        dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
         start_time = time.time() * 1000
-        # Execute query with timeout
-        colunms, sql_result = db_conn.query_ex(sql, timeout=30)
-        # Safely convert and process results
-        field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(
-            colunms,
-            [
-                tuple(str(x) if x is not None else None for x in row)
-                for row in sql_result
-            ],
-            sql,
-        )
+        # Use the parameterized query and parameters
+        colunms, sql_result = db_conn.query_ex(result, params=params, timeout=30)
+        # Convert result type safely
+        sql_result = [
+            tuple(str(x) if x is not None else None for x in row) for row in sql_result
+        ]
         # Calculate execution time
         end_time = time.time() * 1000
         sql_run_data: SqlRunData = SqlRunData(
             result_info="",
             run_cost=(end_time - start_time) / 1000,
             colunms=colunms,
-            values=[list(row) for row in sql_result],
+            values=sql_result,
         )
-        return Result.succ(
-            ChartRunData(
-                sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type
-            )
-        )
+        chart_values = []
+        for i in range(len(sql_result)):
+            row = sql_result[i]
+            chart_values.append(
+                {
+                    "name": row[0],
+                    "type": "value",
+                    "value": row[1] if len(row) > 1 else "0",
+                }
+            )
+        chart_data: ChartRunData = ChartRunData(
+            sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type
+        )
+        return Result.succ(chart_data)
     except Exception as e:
-        logger.exception("Chart sql run failed!")
-        sql_result = SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[])
-        return Result.succ(
-            ChartRunData(sql_data=sql_result, chart_values=[], chart_type=chart_type)
-        )
+        logger.error(f"chart_run exception: {str(e)}", exc_info=True)
+        return Result.failed(msg=str(e))

 @router.post("/v1/chart/editor/submit", response_model=Result[bool])
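
The rewritten chart_run no longer routes results through DashboardDataLoader; it builds chart_values directly from the first two columns of each row. A small standalone sketch of that mapping, using made-up sample rows:

# Made-up rows in the shape query_ex returns after the str() conversion above.
sql_result = [("beijing", "120"), ("shanghai", "95")]

chart_values = []
for row in sql_result:
    chart_values.append(
        {
            "name": row[0],                            # first column becomes the label
            "type": "value",
            "value": row[1] if len(row) > 1 else "0",  # second column becomes the value
        }
    )
print(chart_values)
# [{'name': 'beijing', 'type': 'value', 'value': '120'},
#  {'name': 'shanghai', 'type': 'value', 'value': '95'}]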


@@ -477,7 +477,11 @@ class RDBMSConnector(BaseConnector):
         return self._query(sql)

     def query_ex(
-        self, query: str, fetch: str = "all", timeout: Optional[float] = None
+        self,
+        query: str,
+        params: Optional[Dict[str, Any]] = None,
+        fetch: str = "all",
+        timeout: Optional[float] = None,
     ) -> Tuple[List[str], Optional[List]]:
         """Execute a SQL command and return the results with optional timeout.

@@ -485,6 +489,7 @@ class RDBMSConnector(BaseConnector):
         Args:
             query (str): SQL query to run
+            params (Optional[dict]): Parameters for the query
             fetch (str): fetch type, either 'all' or 'one'
             timeout (Optional[float]): Query timeout in seconds. If None, no timeout is
                 applied.

@@ -501,13 +506,21 @@ class RDBMSConnector(BaseConnector):
             return [], None
         query = self._format_sql(query)

-        def _execute_query(session, sql_text):
-            cursor = session.execute(sql_text)
+        # Initialize params if None
+        if params is None:
+            params = {}
+
+        def _execute_query(session, sql_text, query_params):
+            cursor = session.execute(sql_text, query_params)
             if cursor.returns_rows:
                 if fetch == "all":
                     result = cursor.fetchall()
                 elif fetch == "one":
                     result = cursor.fetchone()
+                    if result:
+                        result = [result]
+                    else:
+                        result = []
                 else:
                     raise ValueError("Fetch parameter must be either 'one' or 'all'")
                 field_names = list(cursor.keys())

@@ -526,14 +539,14 @@ class RDBMSConnector(BaseConnector):
                         session.execute(
                             text(f"SET SESSION MAX_EXECUTION_TIME = {mysql_timeout}")
                         )
-                        return _execute_query(session, sql)
+                        return _execute_query(session, sql, params)
                     elif self.dialect == "postgresql":
                         # PostgreSQL: Set statement_timeout in milliseconds
                         session.execute(
                             text(f"SET statement_timeout = {int(timeout * 1000)}")
                         )
-                        return _execute_query(session, sql)
+                        return _execute_query(session, sql, params)
                     elif self.dialect == "oceanbase":
                         # OceanBase: Set ob_query_timeout in microseconds

@@ -541,17 +554,19 @@ class RDBMSConnector(BaseConnector):
                         session.execute(
                             text(f"SET SESSION ob_query_timeout = {ob_timeout}")
                         )
-                        return _execute_query(session, sql)
+                        return _execute_query(session, sql, params)
                     elif self.dialect == "mssql":
                         # MSSQL: Use execution_options if supported by driver
                         sql_with_timeout = sql.execution_options(timeout=int(timeout))
-                        return _execute_query(session, sql_with_timeout)
+                        return _execute_query(session, sql_with_timeout, params)
                     elif self.dialect == "duckdb":
                         # DuckDB: Use ThreadPoolExecutor for timeout
                         with ThreadPoolExecutor(max_workers=1) as executor:
-                            future = executor.submit(_execute_query, session, sql)
+                            future = executor.submit(
+                                _execute_query, session, sql, params
+                            )
                             try:
                                 return future.result(timeout=timeout)
                             except FutureTimeoutError:

@@ -564,10 +579,10 @@ class RDBMSConnector(BaseConnector):
                             f"Timeout not supported for dialect: {self.dialect}, "
                             "proceeding without timeout"
                         )
-                        return _execute_query(session, sql)
+                        return _execute_query(session, sql, params)
                 # No timeout specified, execute normally
-                return _execute_query(session, sql)
+                return _execute_query(session, sql, params)
         except SQLAlchemyError as e:
             if "timeout" in str(e).lower() or "timed out" in str(e).lower():