mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-31 08:33:36 +00:00
[BUG] SQL Injection through CVE Bypass in DB-GPT 0.7.0 (CVE-2024-10835 & CVE-2024-10901) (#2650)
Co-authored-by: nkoorty <amalyshau2002@gmail.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
|
||||
@@ -95,66 +95,109 @@ async def get_editor_sql(
|
||||
return Result.failed(msg="not have sql!")
|
||||
|
||||
|
||||
def sanitize_sql(sql: str, db_type: str = None) -> Tuple[bool, str, dict]:
    """Validate a user-supplied SQL statement and parameterize its literals.

    The statement is normalized (comments stripped, whitespace collapsed),
    checked against a deny-list of dangerous operations, checked against an
    allow-list of leading statement keywords, and finally every
    single-quoted string literal is replaced by a named bind parameter.

    Args:
        sql: The raw SQL text to validate.
        db_type: Optional dialect name. When it equals ``"duckdb"``
            (case-insensitive) additional DuckDB-specific patterns are
            blocked.

    Returns:
        A tuple ``(is_safe, result, params)``. When ``is_safe`` is ``True``,
        ``result`` is the normalized SQL with ``:param_N`` placeholders and
        ``params`` maps each placeholder name to the literal's value. When
        ``is_safe`` is ``False``, ``result`` is the rejection reason and
        ``params`` is empty.
    """
    # Reject non-string input up front instead of failing inside ``re.sub``.
    if not isinstance(sql, str):
        return False, "SQL statement must be a string", {}

    # Normalize SQL (remove comments and excess whitespace).
    # DOTALL is required so that block comments spanning several lines are
    # removed too; otherwise "INTO/*\n*/OUTFILE" would survive normalization
    # and slip past the whitespace-based deny-list patterns below.
    # NOTE(review): comment stripping is not string-literal aware, so comment
    # markers inside quoted text are stripped as well.
    sql = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL)
    sql = re.sub(r"--.*?$", " ", sql, flags=re.MULTILINE)
    sql = re.sub(r"\s+", " ", sql).strip()

    # Block multiple statements: a semicolon followed by anything other than
    # the end of the (already stripped) statement is rejected.
    if re.search(r";\s*(?!--|\*/|$)", sql):
        return False, "Multiple SQL statements are not allowed", {}

    # Block dangerous operations for all databases
    dangerous_patterns = [
        r"(?i)INTO\s+(?:OUT|DUMP)FILE",
        r"(?i)LOAD\s+DATA",
        r"(?i)SYSTEM",
        r"(?i)EXEC\s+",
        r"(?i)SHELL\b",
        r"(?i)DROP\s+DATABASE",
        r"(?i)DROP\s+USER",
        r"(?i)GRANT\s+",
        r"(?i)REVOKE\s+",
        r"(?i)ALTER\s+(USER|DATABASE)",
    ]

    # Add DuckDB specific patterns. The dialect name is compared
    # case-insensitively so a differently cased name cannot skip these rules.
    if (db_type or "").lower() == "duckdb":
        dangerous_patterns.extend(
            [
                r"(?i)COPY\b",
                r"(?i)EXPORT\b",
                r"(?i)IMPORT\b",
                r"(?i)INSTALL\b",
                r"(?i)READ_\w+\b",
                r"(?i)WRITE_\w+\b",
                r"(?i)\.EXECUTE\(",
                r"(?i)PRAGMA\b",
            ]
        )

    for pattern in dangerous_patterns:
        if re.search(pattern, sql):
            return False, f"Operation not allowed: {pattern}", {}

    # Allow SELECT, CREATE TABLE, INSERT, UPDATE, DELETE and ALTER TABLE.
    # The statement is not restricted to read-only operations.
    allowed_operations = re.match(
        r"(?i)^\s*(SELECT|CREATE\s+TABLE|INSERT\s+INTO|UPDATE|DELETE\s+FROM|ALTER\s+TABLE)\b",
        sql,
    )
    if not allowed_operations:
        return (
            False,
            "Operation not supported. Only SELECT, CREATE TABLE, INSERT, UPDATE, "
            "DELETE and ALTER TABLE operations are allowed",
            {},
        )

    # Extract parameters (simplified)
    params: dict = {}

    def replace_string(match):
        # Placeholders are numbered in order of appearance.
        param_name = f"param_{len(params)}"
        # A doubled quote inside a literal is the SQL escape for one quote.
        params[param_name] = match.group(1).replace("''", "'")
        return f":{param_name}"

    # Replace string literals with parameters. The pattern accepts '' inside
    # a literal so that 'it''s' becomes one parameter instead of two adjacent
    # placeholders. Backslash-escaped quotes are not recognized.
    parameterized_sql = re.sub(r"'((?:[^']|'')*)'", replace_string, sql)

    return True, parameterized_sql, params
|
||||
|
||||
|
||||
@router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
|
||||
async def editor_sql_run(run_param: dict = Body()):
|
||||
logger.info(f"editor_sql_run:{run_param}")
|
||||
db_name = run_param["db_name"]
|
||||
sql = run_param["sql"]
|
||||
|
||||
if not db_name and not sql:
|
||||
return Result.failed(msg="SQL run param error!")
|
||||
|
||||
# Validate database type and prevent dangerous operations
|
||||
# Get database connection
|
||||
conn = CFG.local_db_manager.get_connector(db_name)
|
||||
db_type = getattr(conn, "db_type", "").lower()
|
||||
|
||||
# Block dangerous operations for DuckDB
|
||||
if db_type == "duckdb":
|
||||
# Block file operations and system commands
|
||||
dangerous_keywords = [
|
||||
# File operations
|
||||
"copy",
|
||||
"export",
|
||||
"import",
|
||||
"load",
|
||||
"install",
|
||||
"read_",
|
||||
"write_",
|
||||
"save",
|
||||
"from_",
|
||||
"to_",
|
||||
# System commands
|
||||
"create_",
|
||||
"drop_",
|
||||
".execute(",
|
||||
"system",
|
||||
"shell",
|
||||
# Additional DuckDB specific operations
|
||||
"attach",
|
||||
"detach",
|
||||
"pragma",
|
||||
"checkpoint",
|
||||
"load_extension",
|
||||
"unload_extension",
|
||||
# File paths
|
||||
"/'",
|
||||
"'/'",
|
||||
"\\",
|
||||
"://",
|
||||
]
|
||||
sql_lower = sql.lower().replace(" ", "") # Remove spaces to prevent bypass
|
||||
if any(keyword in sql_lower for keyword in dangerous_keywords):
|
||||
logger.warning(f"Blocked dangerous SQL operation attempt: {sql}")
|
||||
return Result.failed(msg="Operation not allowed for security reasons")
|
||||
|
||||
# Additional check for file path patterns
|
||||
if re.search(r"['\"].*[/\\].*['\"]", sql):
|
||||
logger.warning(f"Blocked file path in SQL: {sql}")
|
||||
return Result.failed(msg="File operations not allowed")
|
||||
# Sanitize and parameterize the SQL query
|
||||
is_safe, result, params = sanitize_sql(sql, db_type)
|
||||
if not is_safe:
|
||||
logger.warning(f"Blocked dangerous SQL: {sql}")
|
||||
return Result.failed(msg=f"Operation not allowed: {result}")
|
||||
|
||||
try:
|
||||
start_time = time.time() * 1000
|
||||
# Add timeout protection
|
||||
colunms, sql_result = conn.query_ex(sql, timeout=30)
|
||||
# Use the parameterized query and parameters
|
||||
colunms, sql_result = conn.query_ex(result, params=params, timeout=30)
|
||||
# Convert result type safely
|
||||
sql_result = [
|
||||
tuple(str(x) if x is not None else None for x in row) for row in sql_result
|
||||
@@ -216,103 +259,57 @@ async def get_editor_chart_info(
|
||||
|
||||
|
||||
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
|
||||
async def editor_chart_run(run_param: dict = Body()):
|
||||
logger.info(f"editor_chart_run:{run_param}")
|
||||
async def chart_run(run_param: dict = Body()):
|
||||
logger.info(f"chart_run:{run_param}")
|
||||
db_name = run_param["db_name"]
|
||||
sql = run_param["sql"]
|
||||
chart_type = run_param["chart_type"]
|
||||
|
||||
# Validate input parameters
|
||||
if not db_name or not sql or not chart_type:
|
||||
return Result.failed("Required parameters missing")
|
||||
# Get database connection
|
||||
db_conn = CFG.local_db_manager.get_connector(db_name)
|
||||
db_type = getattr(db_conn, "db_type", "").lower()
|
||||
|
||||
# Sanitize and parameterize the SQL query
|
||||
is_safe, result, params = sanitize_sql(sql, db_type)
|
||||
if not is_safe:
|
||||
logger.warning(f"Blocked dangerous SQL: {sql}")
|
||||
return Result.failed(msg=f"Operation not allowed: {result}")
|
||||
|
||||
try:
|
||||
# Validate database type and prevent dangerous operations
|
||||
db_conn = CFG.local_db_manager.get_connector(db_name)
|
||||
db_type = getattr(db_conn, "db_type", "").lower()
|
||||
|
||||
# Block dangerous operations for DuckDB
|
||||
if db_type == "duckdb":
|
||||
# Block file operations and system commands
|
||||
dangerous_keywords = [
|
||||
# File operations
|
||||
"copy",
|
||||
"export",
|
||||
"import",
|
||||
"load",
|
||||
"install",
|
||||
"read_",
|
||||
"write_",
|
||||
"save",
|
||||
"from_",
|
||||
"to_",
|
||||
# System commands
|
||||
"create_",
|
||||
"drop_",
|
||||
".execute(",
|
||||
"system",
|
||||
"shell",
|
||||
# Additional DuckDB specific operations
|
||||
"attach",
|
||||
"detach",
|
||||
"pragma",
|
||||
"checkpoint",
|
||||
"load_extension",
|
||||
"unload_extension",
|
||||
# File paths
|
||||
"/'",
|
||||
"'/'",
|
||||
"\\",
|
||||
"://",
|
||||
]
|
||||
sql_lower = sql.lower().replace(" ", "") # Remove spaces to prevent bypass
|
||||
if any(keyword in sql_lower for keyword in dangerous_keywords):
|
||||
logger.warning(
|
||||
f"Blocked dangerous SQL operation attempt in chart: {sql}"
|
||||
)
|
||||
return Result.failed(msg="Operation not allowed for security reasons")
|
||||
|
||||
# Additional check for file path patterns
|
||||
if re.search(r"['\"].*[/\\].*['\"]", sql):
|
||||
logger.warning(f"Blocked file path in chart SQL: {sql}")
|
||||
return Result.failed(msg="File operations not allowed")
|
||||
|
||||
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
|
||||
|
||||
start_time = time.time() * 1000
|
||||
|
||||
# Execute query with timeout
|
||||
colunms, sql_result = db_conn.query_ex(sql, timeout=30)
|
||||
|
||||
# Safely convert and process results
|
||||
field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(
|
||||
colunms,
|
||||
[
|
||||
tuple(str(x) if x is not None else None for x in row)
|
||||
for row in sql_result
|
||||
],
|
||||
sql,
|
||||
)
|
||||
|
||||
# Use the parameterized query and parameters
|
||||
colunms, sql_result = db_conn.query_ex(result, params=params, timeout=30)
|
||||
# Convert result type safely
|
||||
sql_result = [
|
||||
tuple(str(x) if x is not None else None for x in row) for row in sql_result
|
||||
]
|
||||
# Calculate execution time
|
||||
end_time = time.time() * 1000
|
||||
sql_run_data: SqlRunData = SqlRunData(
|
||||
result_info="",
|
||||
run_cost=(end_time - start_time) / 1000,
|
||||
colunms=colunms,
|
||||
values=[list(row) for row in sql_result],
|
||||
values=sql_result,
|
||||
)
|
||||
return Result.succ(
|
||||
ChartRunData(
|
||||
sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type
|
||||
|
||||
chart_values = []
|
||||
for i in range(len(sql_result)):
|
||||
row = sql_result[i]
|
||||
chart_values.append(
|
||||
{
|
||||
"name": row[0],
|
||||
"type": "value",
|
||||
"value": row[1] if len(row) > 1 else "0",
|
||||
}
|
||||
)
|
||||
|
||||
chart_data: ChartRunData = ChartRunData(
|
||||
sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type
|
||||
)
|
||||
return Result.succ(chart_data)
|
||||
except Exception as e:
|
||||
logger.exception("Chart sql run failed!")
|
||||
sql_result = SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[])
|
||||
return Result.succ(
|
||||
ChartRunData(sql_data=sql_result, chart_values=[], chart_type=chart_type)
|
||||
)
|
||||
logger.error(f"chart_run exception: {str(e)}", exc_info=True)
|
||||
return Result.failed(msg=str(e))
|
||||
|
||||
|
||||
@router.post("/v1/chart/editor/submit", response_model=Result[bool])
|
||||
|
@@ -477,7 +477,11 @@ class RDBMSConnector(BaseConnector):
|
||||
return self._query(sql)
|
||||
|
||||
def query_ex(
|
||||
self, query: str, fetch: str = "all", timeout: Optional[float] = None
|
||||
self,
|
||||
query: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
fetch: str = "all",
|
||||
timeout: Optional[float] = None,
|
||||
) -> Tuple[List[str], Optional[List]]:
|
||||
"""Execute a SQL command and return the results with optional timeout.
|
||||
|
||||
@@ -485,6 +489,7 @@ class RDBMSConnector(BaseConnector):
|
||||
|
||||
Args:
|
||||
query (str): SQL query to run
|
||||
params (Optional[dict]): Parameters for the query
|
||||
fetch (str): fetch type, either 'all' or 'one'
|
||||
timeout (Optional[float]): Query timeout in seconds. If None, no timeout is
|
||||
applied.
|
||||
@@ -501,13 +506,21 @@ class RDBMSConnector(BaseConnector):
|
||||
return [], None
|
||||
query = self._format_sql(query)
|
||||
|
||||
def _execute_query(session, sql_text):
|
||||
cursor = session.execute(sql_text)
|
||||
# Initialize params if None
|
||||
if params is None:
|
||||
params = {}
|
||||
|
||||
def _execute_query(session, sql_text, query_params):
|
||||
cursor = session.execute(sql_text, query_params)
|
||||
if cursor.returns_rows:
|
||||
if fetch == "all":
|
||||
result = cursor.fetchall()
|
||||
elif fetch == "one":
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
result = [result]
|
||||
else:
|
||||
result = []
|
||||
else:
|
||||
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
||||
field_names = list(cursor.keys())
|
||||
@@ -526,14 +539,14 @@ class RDBMSConnector(BaseConnector):
|
||||
session.execute(
|
||||
text(f"SET SESSION MAX_EXECUTION_TIME = {mysql_timeout}")
|
||||
)
|
||||
return _execute_query(session, sql)
|
||||
return _execute_query(session, sql, params)
|
||||
|
||||
elif self.dialect == "postgresql":
|
||||
# PostgreSQL: Set statement_timeout in milliseconds
|
||||
session.execute(
|
||||
text(f"SET statement_timeout = {int(timeout * 1000)}")
|
||||
)
|
||||
return _execute_query(session, sql)
|
||||
return _execute_query(session, sql, params)
|
||||
|
||||
elif self.dialect == "oceanbase":
|
||||
# OceanBase: Set ob_query_timeout in microseconds
|
||||
@@ -541,17 +554,19 @@ class RDBMSConnector(BaseConnector):
|
||||
session.execute(
|
||||
text(f"SET SESSION ob_query_timeout = {ob_timeout}")
|
||||
)
|
||||
return _execute_query(session, sql)
|
||||
return _execute_query(session, sql, params)
|
||||
|
||||
elif self.dialect == "mssql":
|
||||
# MSSQL: Use execution_options if supported by driver
|
||||
sql_with_timeout = sql.execution_options(timeout=int(timeout))
|
||||
return _execute_query(session, sql_with_timeout)
|
||||
return _execute_query(session, sql_with_timeout, params)
|
||||
|
||||
elif self.dialect == "duckdb":
|
||||
# DuckDB: Use ThreadPoolExecutor for timeout
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(_execute_query, session, sql)
|
||||
future = executor.submit(
|
||||
_execute_query, session, sql, params
|
||||
)
|
||||
try:
|
||||
return future.result(timeout=timeout)
|
||||
except FutureTimeoutError:
|
||||
@@ -564,10 +579,10 @@ class RDBMSConnector(BaseConnector):
|
||||
f"Timeout not supported for dialect: {self.dialect}, "
|
||||
"proceeding without timeout"
|
||||
)
|
||||
return _execute_query(session, sql)
|
||||
return _execute_query(session, sql, params)
|
||||
|
||||
# No timeout specified, execute normally
|
||||
return _execute_query(session, sql)
|
||||
return _execute_query(session, sql, params)
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
if "timeout" in str(e).lower() or "timed out" in str(e).lower():
|
||||
|
Reference in New Issue
Block a user