[BUG] SQL Injection through CVE Bypass in DB-GPT 0.7.0 (CVE-2024-10835 & CVE-2024-10901) (#2650)

Co-authored-by: nkoorty <amalyshau2002@gmail.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
Gecko Security
2025-04-28 07:54:41 +02:00
committed by GitHub
parent b16c6793ec
commit bcb43266cf
2 changed files with 149 additions and 137 deletions

View File

@@ -2,7 +2,7 @@ import json
import logging
import re
import time
from typing import Dict, List
from typing import Dict, List, Tuple
from fastapi import APIRouter, Body, Depends
@@ -95,66 +95,109 @@ async def get_editor_sql(
return Result.failed(msg="not have sql!")
def sanitize_sql(sql: str, db_type: str = None) -> Tuple[bool, str, dict]:
"""Simple SQL sanitizer to prevent injection.
Returns:
Tuple of (is_safe, reason, params)
"""
# Normalize SQL (remove comments and excess whitespace)
sql = re.sub(r"/\*.*?\*/", " ", sql)
sql = re.sub(r"--.*?$", " ", sql, flags=re.MULTILINE)
sql = re.sub(r"\s+", " ", sql).strip()
# Block multiple statements
if re.search(r";\s*(?!--|\*/|$)", sql):
return False, "Multiple SQL statements are not allowed", {}
# Block dangerous operations for all databases
dangerous_patterns = [
r"(?i)INTO\s+(?:OUT|DUMP)FILE",
r"(?i)LOAD\s+DATA",
r"(?i)SYSTEM",
r"(?i)EXEC\s+",
r"(?i)SHELL\b",
r"(?i)DROP\s+DATABASE",
r"(?i)DROP\s+USER",
r"(?i)GRANT\s+",
r"(?i)REVOKE\s+",
r"(?i)ALTER\s+(USER|DATABASE)",
]
# Add DuckDB specific patterns
if db_type == "duckdb":
dangerous_patterns.extend(
[
r"(?i)COPY\b",
r"(?i)EXPORT\b",
r"(?i)IMPORT\b",
r"(?i)INSTALL\b",
r"(?i)READ_\w+\b",
r"(?i)WRITE_\w+\b",
r"(?i)\.EXECUTE\(",
r"(?i)PRAGMA\b",
]
)
for pattern in dangerous_patterns:
if re.search(pattern, sql):
return False, f"Operation not allowed: {pattern}", {}
# Allow SELECT, CREATE TABLE, INSERT, UPDATE, and DELETE operations
# We're no longer restricting to read-only operations
allowed_operations = re.match(
r"(?i)^\s*(SELECT|CREATE\s+TABLE|INSERT\s+INTO|UPDATE|DELETE\s+FROM|ALTER\s+TABLE)\b",
sql,
)
if not allowed_operations:
return (
False,
"Operation not supported. Only SELECT, CREATE TABLE, INSERT, UPDATE, "
"DELETE and ALTER TABLE operations are allowed",
{},
)
# Extract parameters (simplified)
params = {}
param_count = 0
# Extract string literals
def replace_string(match):
nonlocal param_count
param_name = f"param_{param_count}"
params[param_name] = match.group(1)
param_count += 1
return f":{param_name}"
# Replace string literals with parameters
parameterized_sql = re.sub(r"'([^']*)'", replace_string, sql)
return True, parameterized_sql, params
@router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
async def editor_sql_run(run_param: dict = Body()):
logger.info(f"editor_sql_run:{run_param}")
db_name = run_param["db_name"]
sql = run_param["sql"]
if not db_name and not sql:
return Result.failed(msg="SQL run param error")
# Validate database type and prevent dangerous operations
# Get database connection
conn = CFG.local_db_manager.get_connector(db_name)
db_type = getattr(conn, "db_type", "").lower()
# Block dangerous operations for DuckDB
if db_type == "duckdb":
# Block file operations and system commands
dangerous_keywords = [
# File operations
"copy",
"export",
"import",
"load",
"install",
"read_",
"write_",
"save",
"from_",
"to_",
# System commands
"create_",
"drop_",
".execute(",
"system",
"shell",
# Additional DuckDB specific operations
"attach",
"detach",
"pragma",
"checkpoint",
"load_extension",
"unload_extension",
# File paths
"/'",
"'/'",
"\\",
"://",
]
sql_lower = sql.lower().replace(" ", "") # Remove spaces to prevent bypass
if any(keyword in sql_lower for keyword in dangerous_keywords):
logger.warning(f"Blocked dangerous SQL operation attempt: {sql}")
return Result.failed(msg="Operation not allowed for security reasons")
# Additional check for file path patterns
if re.search(r"['\"].*[/\\].*['\"]", sql):
logger.warning(f"Blocked file path in SQL: {sql}")
return Result.failed(msg="File operations not allowed")
# Sanitize and parameterize the SQL query
is_safe, result, params = sanitize_sql(sql, db_type)
if not is_safe:
logger.warning(f"Blocked dangerous SQL: {sql}")
return Result.failed(msg=f"Operation not allowed: {result}")
try:
start_time = time.time() * 1000
# Add timeout protection
colunms, sql_result = conn.query_ex(sql, timeout=30)
# Use the parameterized query and parameters
colunms, sql_result = conn.query_ex(result, params=params, timeout=30)
# Convert result type safely
sql_result = [
tuple(str(x) if x is not None else None for x in row) for row in sql_result
@@ -216,103 +259,57 @@ async def get_editor_chart_info(
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
async def editor_chart_run(run_param: dict = Body()):
logger.info(f"editor_chart_run:{run_param}")
async def chart_run(run_param: dict = Body()):
logger.info(f"chart_run:{run_param}")
db_name = run_param["db_name"]
sql = run_param["sql"]
chart_type = run_param["chart_type"]
# Validate input parameters
if not db_name or not sql or not chart_type:
return Result.failed("Required parameters missing")
# Get database connection
db_conn = CFG.local_db_manager.get_connector(db_name)
db_type = getattr(db_conn, "db_type", "").lower()
# Sanitize and parameterize the SQL query
is_safe, result, params = sanitize_sql(sql, db_type)
if not is_safe:
logger.warning(f"Blocked dangerous SQL: {sql}")
return Result.failed(msg=f"Operation not allowed: {result}")
try:
# Validate database type and prevent dangerous operations
db_conn = CFG.local_db_manager.get_connector(db_name)
db_type = getattr(db_conn, "db_type", "").lower()
# Block dangerous operations for DuckDB
if db_type == "duckdb":
# Block file operations and system commands
dangerous_keywords = [
# File operations
"copy",
"export",
"import",
"load",
"install",
"read_",
"write_",
"save",
"from_",
"to_",
# System commands
"create_",
"drop_",
".execute(",
"system",
"shell",
# Additional DuckDB specific operations
"attach",
"detach",
"pragma",
"checkpoint",
"load_extension",
"unload_extension",
# File paths
"/'",
"'/'",
"\\",
"://",
]
sql_lower = sql.lower().replace(" ", "") # Remove spaces to prevent bypass
if any(keyword in sql_lower for keyword in dangerous_keywords):
logger.warning(
f"Blocked dangerous SQL operation attempt in chart: {sql}"
)
return Result.failed(msg="Operation not allowed for security reasons")
# Additional check for file path patterns
if re.search(r"['\"].*[/\\].*['\"]", sql):
logger.warning(f"Blocked file path in chart SQL: {sql}")
return Result.failed(msg="File operations not allowed")
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
start_time = time.time() * 1000
# Execute query with timeout
colunms, sql_result = db_conn.query_ex(sql, timeout=30)
# Safely convert and process results
field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(
colunms,
[
tuple(str(x) if x is not None else None for x in row)
for row in sql_result
],
sql,
)
# Use the parameterized query and parameters
colunms, sql_result = db_conn.query_ex(result, params=params, timeout=30)
# Convert result type safely
sql_result = [
tuple(str(x) if x is not None else None for x in row) for row in sql_result
]
# Calculate execution time
end_time = time.time() * 1000
sql_run_data: SqlRunData = SqlRunData(
result_info="",
run_cost=(end_time - start_time) / 1000,
colunms=colunms,
values=[list(row) for row in sql_result],
values=sql_result,
)
return Result.succ(
ChartRunData(
sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type
chart_values = []
for i in range(len(sql_result)):
row = sql_result[i]
chart_values.append(
{
"name": row[0],
"type": "value",
"value": row[1] if len(row) > 1 else "0",
}
)
chart_data: ChartRunData = ChartRunData(
sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type
)
return Result.succ(chart_data)
except Exception as e:
logger.exception("Chart sql run failed!")
sql_result = SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[])
return Result.succ(
ChartRunData(sql_data=sql_result, chart_values=[], chart_type=chart_type)
)
logger.error(f"chart_run exception: {str(e)}", exc_info=True)
return Result.failed(msg=str(e))
@router.post("/v1/chart/editor/submit", response_model=Result[bool])

View File

@@ -477,7 +477,11 @@ class RDBMSConnector(BaseConnector):
return self._query(sql)
def query_ex(
self, query: str, fetch: str = "all", timeout: Optional[float] = None
self,
query: str,
params: Optional[Dict[str, Any]] = None,
fetch: str = "all",
timeout: Optional[float] = None,
) -> Tuple[List[str], Optional[List]]:
"""Execute a SQL command and return the results with optional timeout.
@@ -485,6 +489,7 @@ class RDBMSConnector(BaseConnector):
Args:
query (str): SQL query to run
params (Optional[dict]): Parameters for the query
fetch (str): fetch type, either 'all' or 'one'
timeout (Optional[float]): Query timeout in seconds. If None, no timeout is
applied.
@@ -501,13 +506,21 @@ class RDBMSConnector(BaseConnector):
return [], None
query = self._format_sql(query)
def _execute_query(session, sql_text):
cursor = session.execute(sql_text)
# Initialize params if None
if params is None:
params = {}
def _execute_query(session, sql_text, query_params):
cursor = session.execute(sql_text, query_params)
if cursor.returns_rows:
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone()
if result:
result = [result]
else:
result = []
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
field_names = list(cursor.keys())
@@ -526,14 +539,14 @@ class RDBMSConnector(BaseConnector):
session.execute(
text(f"SET SESSION MAX_EXECUTION_TIME = {mysql_timeout}")
)
return _execute_query(session, sql)
return _execute_query(session, sql, params)
elif self.dialect == "postgresql":
# PostgreSQL: Set statement_timeout in milliseconds
session.execute(
text(f"SET statement_timeout = {int(timeout * 1000)}")
)
return _execute_query(session, sql)
return _execute_query(session, sql, params)
elif self.dialect == "oceanbase":
# OceanBase: Set ob_query_timeout in microseconds
@@ -541,17 +554,19 @@ class RDBMSConnector(BaseConnector):
session.execute(
text(f"SET SESSION ob_query_timeout = {ob_timeout}")
)
return _execute_query(session, sql)
return _execute_query(session, sql, params)
elif self.dialect == "mssql":
# MSSQL: Use execution_options if supported by driver
sql_with_timeout = sql.execution_options(timeout=int(timeout))
return _execute_query(session, sql_with_timeout)
return _execute_query(session, sql_with_timeout, params)
elif self.dialect == "duckdb":
# DuckDB: Use ThreadPoolExecutor for timeout
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(_execute_query, session, sql)
future = executor.submit(
_execute_query, session, sql, params
)
try:
return future.result(timeout=timeout)
except FutureTimeoutError:
@@ -564,10 +579,10 @@ class RDBMSConnector(BaseConnector):
f"Timeout not supported for dialect: {self.dialect}, "
"proceeding without timeout"
)
return _execute_query(session, sql)
return _execute_query(session, sql, params)
# No timeout specified, execute normally
return _execute_query(session, sql)
return _execute_query(session, sql, params)
except SQLAlchemyError as e:
if "timeout" in str(e).lower() or "timed out" in str(e).lower():