mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 10:05:13 +00:00
[BUG] SQL Injection through CVE Bypass in DB-GPT 0.7.0 (CVE-2024-10835 & CVE-2024-10901) (#2650)
Co-authored-by: nkoorty <amalyshau2002@gmail.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
@@ -2,7 +2,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, Depends
|
from fastapi import APIRouter, Body, Depends
|
||||||
|
|
||||||
@@ -95,66 +95,109 @@ async def get_editor_sql(
|
|||||||
return Result.failed(msg="not have sql!")
|
return Result.failed(msg="not have sql!")
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_sql(sql: str, db_type: str = None) -> Tuple[bool, str, dict]:
|
||||||
|
"""Simple SQL sanitizer to prevent injection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_safe, reason, params)
|
||||||
|
"""
|
||||||
|
# Normalize SQL (remove comments and excess whitespace)
|
||||||
|
sql = re.sub(r"/\*.*?\*/", " ", sql)
|
||||||
|
sql = re.sub(r"--.*?$", " ", sql, flags=re.MULTILINE)
|
||||||
|
sql = re.sub(r"\s+", " ", sql).strip()
|
||||||
|
|
||||||
|
# Block multiple statements
|
||||||
|
if re.search(r";\s*(?!--|\*/|$)", sql):
|
||||||
|
return False, "Multiple SQL statements are not allowed", {}
|
||||||
|
|
||||||
|
# Block dangerous operations for all databases
|
||||||
|
dangerous_patterns = [
|
||||||
|
r"(?i)INTO\s+(?:OUT|DUMP)FILE",
|
||||||
|
r"(?i)LOAD\s+DATA",
|
||||||
|
r"(?i)SYSTEM",
|
||||||
|
r"(?i)EXEC\s+",
|
||||||
|
r"(?i)SHELL\b",
|
||||||
|
r"(?i)DROP\s+DATABASE",
|
||||||
|
r"(?i)DROP\s+USER",
|
||||||
|
r"(?i)GRANT\s+",
|
||||||
|
r"(?i)REVOKE\s+",
|
||||||
|
r"(?i)ALTER\s+(USER|DATABASE)",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add DuckDB specific patterns
|
||||||
|
if db_type == "duckdb":
|
||||||
|
dangerous_patterns.extend(
|
||||||
|
[
|
||||||
|
r"(?i)COPY\b",
|
||||||
|
r"(?i)EXPORT\b",
|
||||||
|
r"(?i)IMPORT\b",
|
||||||
|
r"(?i)INSTALL\b",
|
||||||
|
r"(?i)READ_\w+\b",
|
||||||
|
r"(?i)WRITE_\w+\b",
|
||||||
|
r"(?i)\.EXECUTE\(",
|
||||||
|
r"(?i)PRAGMA\b",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
for pattern in dangerous_patterns:
|
||||||
|
if re.search(pattern, sql):
|
||||||
|
return False, f"Operation not allowed: {pattern}", {}
|
||||||
|
|
||||||
|
# Allow SELECT, CREATE TABLE, INSERT, UPDATE, and DELETE operations
|
||||||
|
# We're no longer restricting to read-only operations
|
||||||
|
allowed_operations = re.match(
|
||||||
|
r"(?i)^\s*(SELECT|CREATE\s+TABLE|INSERT\s+INTO|UPDATE|DELETE\s+FROM|ALTER\s+TABLE)\b",
|
||||||
|
sql,
|
||||||
|
)
|
||||||
|
if not allowed_operations:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"Operation not supported. Only SELECT, CREATE TABLE, INSERT, UPDATE, "
|
||||||
|
"DELETE and ALTER TABLE operations are allowed",
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract parameters (simplified)
|
||||||
|
params = {}
|
||||||
|
param_count = 0
|
||||||
|
|
||||||
|
# Extract string literals
|
||||||
|
def replace_string(match):
|
||||||
|
nonlocal param_count
|
||||||
|
param_name = f"param_{param_count}"
|
||||||
|
params[param_name] = match.group(1)
|
||||||
|
param_count += 1
|
||||||
|
return f":{param_name}"
|
||||||
|
|
||||||
|
# Replace string literals with parameters
|
||||||
|
parameterized_sql = re.sub(r"'([^']*)'", replace_string, sql)
|
||||||
|
|
||||||
|
return True, parameterized_sql, params
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
|
@router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
|
||||||
async def editor_sql_run(run_param: dict = Body()):
|
async def editor_sql_run(run_param: dict = Body()):
|
||||||
logger.info(f"editor_sql_run:{run_param}")
|
logger.info(f"editor_sql_run:{run_param}")
|
||||||
db_name = run_param["db_name"]
|
db_name = run_param["db_name"]
|
||||||
sql = run_param["sql"]
|
sql = run_param["sql"]
|
||||||
|
|
||||||
if not db_name and not sql:
|
if not db_name and not sql:
|
||||||
return Result.failed(msg="SQL run param error!")
|
return Result.failed(msg="SQL run param error!")
|
||||||
|
|
||||||
# Validate database type and prevent dangerous operations
|
# Get database connection
|
||||||
conn = CFG.local_db_manager.get_connector(db_name)
|
conn = CFG.local_db_manager.get_connector(db_name)
|
||||||
db_type = getattr(conn, "db_type", "").lower()
|
db_type = getattr(conn, "db_type", "").lower()
|
||||||
|
|
||||||
# Block dangerous operations for DuckDB
|
# Sanitize and parameterize the SQL query
|
||||||
if db_type == "duckdb":
|
is_safe, result, params = sanitize_sql(sql, db_type)
|
||||||
# Block file operations and system commands
|
if not is_safe:
|
||||||
dangerous_keywords = [
|
logger.warning(f"Blocked dangerous SQL: {sql}")
|
||||||
# File operations
|
return Result.failed(msg=f"Operation not allowed: {result}")
|
||||||
"copy",
|
|
||||||
"export",
|
|
||||||
"import",
|
|
||||||
"load",
|
|
||||||
"install",
|
|
||||||
"read_",
|
|
||||||
"write_",
|
|
||||||
"save",
|
|
||||||
"from_",
|
|
||||||
"to_",
|
|
||||||
# System commands
|
|
||||||
"create_",
|
|
||||||
"drop_",
|
|
||||||
".execute(",
|
|
||||||
"system",
|
|
||||||
"shell",
|
|
||||||
# Additional DuckDB specific operations
|
|
||||||
"attach",
|
|
||||||
"detach",
|
|
||||||
"pragma",
|
|
||||||
"checkpoint",
|
|
||||||
"load_extension",
|
|
||||||
"unload_extension",
|
|
||||||
# File paths
|
|
||||||
"/'",
|
|
||||||
"'/'",
|
|
||||||
"\\",
|
|
||||||
"://",
|
|
||||||
]
|
|
||||||
sql_lower = sql.lower().replace(" ", "") # Remove spaces to prevent bypass
|
|
||||||
if any(keyword in sql_lower for keyword in dangerous_keywords):
|
|
||||||
logger.warning(f"Blocked dangerous SQL operation attempt: {sql}")
|
|
||||||
return Result.failed(msg="Operation not allowed for security reasons")
|
|
||||||
|
|
||||||
# Additional check for file path patterns
|
|
||||||
if re.search(r"['\"].*[/\\].*['\"]", sql):
|
|
||||||
logger.warning(f"Blocked file path in SQL: {sql}")
|
|
||||||
return Result.failed(msg="File operations not allowed")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_time = time.time() * 1000
|
start_time = time.time() * 1000
|
||||||
# Add timeout protection
|
# Use the parameterized query and parameters
|
||||||
colunms, sql_result = conn.query_ex(sql, timeout=30)
|
colunms, sql_result = conn.query_ex(result, params=params, timeout=30)
|
||||||
# Convert result type safely
|
# Convert result type safely
|
||||||
sql_result = [
|
sql_result = [
|
||||||
tuple(str(x) if x is not None else None for x in row) for row in sql_result
|
tuple(str(x) if x is not None else None for x in row) for row in sql_result
|
||||||
@@ -216,103 +259,57 @@ async def get_editor_chart_info(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
|
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
|
||||||
async def editor_chart_run(run_param: dict = Body()):
|
async def chart_run(run_param: dict = Body()):
|
||||||
logger.info(f"editor_chart_run:{run_param}")
|
logger.info(f"chart_run:{run_param}")
|
||||||
db_name = run_param["db_name"]
|
db_name = run_param["db_name"]
|
||||||
sql = run_param["sql"]
|
sql = run_param["sql"]
|
||||||
chart_type = run_param["chart_type"]
|
chart_type = run_param["chart_type"]
|
||||||
|
|
||||||
# Validate input parameters
|
# Get database connection
|
||||||
if not db_name or not sql or not chart_type:
|
db_conn = CFG.local_db_manager.get_connector(db_name)
|
||||||
return Result.failed("Required parameters missing")
|
db_type = getattr(db_conn, "db_type", "").lower()
|
||||||
|
|
||||||
|
# Sanitize and parameterize the SQL query
|
||||||
|
is_safe, result, params = sanitize_sql(sql, db_type)
|
||||||
|
if not is_safe:
|
||||||
|
logger.warning(f"Blocked dangerous SQL: {sql}")
|
||||||
|
return Result.failed(msg=f"Operation not allowed: {result}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Validate database type and prevent dangerous operations
|
|
||||||
db_conn = CFG.local_db_manager.get_connector(db_name)
|
|
||||||
db_type = getattr(db_conn, "db_type", "").lower()
|
|
||||||
|
|
||||||
# Block dangerous operations for DuckDB
|
|
||||||
if db_type == "duckdb":
|
|
||||||
# Block file operations and system commands
|
|
||||||
dangerous_keywords = [
|
|
||||||
# File operations
|
|
||||||
"copy",
|
|
||||||
"export",
|
|
||||||
"import",
|
|
||||||
"load",
|
|
||||||
"install",
|
|
||||||
"read_",
|
|
||||||
"write_",
|
|
||||||
"save",
|
|
||||||
"from_",
|
|
||||||
"to_",
|
|
||||||
# System commands
|
|
||||||
"create_",
|
|
||||||
"drop_",
|
|
||||||
".execute(",
|
|
||||||
"system",
|
|
||||||
"shell",
|
|
||||||
# Additional DuckDB specific operations
|
|
||||||
"attach",
|
|
||||||
"detach",
|
|
||||||
"pragma",
|
|
||||||
"checkpoint",
|
|
||||||
"load_extension",
|
|
||||||
"unload_extension",
|
|
||||||
# File paths
|
|
||||||
"/'",
|
|
||||||
"'/'",
|
|
||||||
"\\",
|
|
||||||
"://",
|
|
||||||
]
|
|
||||||
sql_lower = sql.lower().replace(" ", "") # Remove spaces to prevent bypass
|
|
||||||
if any(keyword in sql_lower for keyword in dangerous_keywords):
|
|
||||||
logger.warning(
|
|
||||||
f"Blocked dangerous SQL operation attempt in chart: {sql}"
|
|
||||||
)
|
|
||||||
return Result.failed(msg="Operation not allowed for security reasons")
|
|
||||||
|
|
||||||
# Additional check for file path patterns
|
|
||||||
if re.search(r"['\"].*[/\\].*['\"]", sql):
|
|
||||||
logger.warning(f"Blocked file path in chart SQL: {sql}")
|
|
||||||
return Result.failed(msg="File operations not allowed")
|
|
||||||
|
|
||||||
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
|
|
||||||
|
|
||||||
start_time = time.time() * 1000
|
start_time = time.time() * 1000
|
||||||
|
# Use the parameterized query and parameters
|
||||||
# Execute query with timeout
|
colunms, sql_result = db_conn.query_ex(result, params=params, timeout=30)
|
||||||
colunms, sql_result = db_conn.query_ex(sql, timeout=30)
|
# Convert result type safely
|
||||||
|
sql_result = [
|
||||||
# Safely convert and process results
|
tuple(str(x) if x is not None else None for x in row) for row in sql_result
|
||||||
field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(
|
]
|
||||||
colunms,
|
|
||||||
[
|
|
||||||
tuple(str(x) if x is not None else None for x in row)
|
|
||||||
for row in sql_result
|
|
||||||
],
|
|
||||||
sql,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate execution time
|
# Calculate execution time
|
||||||
end_time = time.time() * 1000
|
end_time = time.time() * 1000
|
||||||
sql_run_data: SqlRunData = SqlRunData(
|
sql_run_data: SqlRunData = SqlRunData(
|
||||||
result_info="",
|
result_info="",
|
||||||
run_cost=(end_time - start_time) / 1000,
|
run_cost=(end_time - start_time) / 1000,
|
||||||
colunms=colunms,
|
colunms=colunms,
|
||||||
values=[list(row) for row in sql_result],
|
values=sql_result,
|
||||||
)
|
)
|
||||||
return Result.succ(
|
|
||||||
ChartRunData(
|
chart_values = []
|
||||||
sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type
|
for i in range(len(sql_result)):
|
||||||
|
row = sql_result[i]
|
||||||
|
chart_values.append(
|
||||||
|
{
|
||||||
|
"name": row[0],
|
||||||
|
"type": "value",
|
||||||
|
"value": row[1] if len(row) > 1 else "0",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
chart_data: ChartRunData = ChartRunData(
|
||||||
|
sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type
|
||||||
)
|
)
|
||||||
|
return Result.succ(chart_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Chart sql run failed!")
|
logger.error(f"chart_run exception: {str(e)}", exc_info=True)
|
||||||
sql_result = SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[])
|
return Result.failed(msg=str(e))
|
||||||
return Result.succ(
|
|
||||||
ChartRunData(sql_data=sql_result, chart_values=[], chart_type=chart_type)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/chart/editor/submit", response_model=Result[bool])
|
@router.post("/v1/chart/editor/submit", response_model=Result[bool])
|
||||||
|
@@ -477,7 +477,11 @@ class RDBMSConnector(BaseConnector):
|
|||||||
return self._query(sql)
|
return self._query(sql)
|
||||||
|
|
||||||
def query_ex(
|
def query_ex(
|
||||||
self, query: str, fetch: str = "all", timeout: Optional[float] = None
|
self,
|
||||||
|
query: str,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
fetch: str = "all",
|
||||||
|
timeout: Optional[float] = None,
|
||||||
) -> Tuple[List[str], Optional[List]]:
|
) -> Tuple[List[str], Optional[List]]:
|
||||||
"""Execute a SQL command and return the results with optional timeout.
|
"""Execute a SQL command and return the results with optional timeout.
|
||||||
|
|
||||||
@@ -485,6 +489,7 @@ class RDBMSConnector(BaseConnector):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): SQL query to run
|
query (str): SQL query to run
|
||||||
|
params (Optional[dict]): Parameters for the query
|
||||||
fetch (str): fetch type, either 'all' or 'one'
|
fetch (str): fetch type, either 'all' or 'one'
|
||||||
timeout (Optional[float]): Query timeout in seconds. If None, no timeout is
|
timeout (Optional[float]): Query timeout in seconds. If None, no timeout is
|
||||||
applied.
|
applied.
|
||||||
@@ -501,13 +506,21 @@ class RDBMSConnector(BaseConnector):
|
|||||||
return [], None
|
return [], None
|
||||||
query = self._format_sql(query)
|
query = self._format_sql(query)
|
||||||
|
|
||||||
def _execute_query(session, sql_text):
|
# Initialize params if None
|
||||||
cursor = session.execute(sql_text)
|
if params is None:
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
def _execute_query(session, sql_text, query_params):
|
||||||
|
cursor = session.execute(sql_text, query_params)
|
||||||
if cursor.returns_rows:
|
if cursor.returns_rows:
|
||||||
if fetch == "all":
|
if fetch == "all":
|
||||||
result = cursor.fetchall()
|
result = cursor.fetchall()
|
||||||
elif fetch == "one":
|
elif fetch == "one":
|
||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
|
if result:
|
||||||
|
result = [result]
|
||||||
|
else:
|
||||||
|
result = []
|
||||||
else:
|
else:
|
||||||
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
||||||
field_names = list(cursor.keys())
|
field_names = list(cursor.keys())
|
||||||
@@ -526,14 +539,14 @@ class RDBMSConnector(BaseConnector):
|
|||||||
session.execute(
|
session.execute(
|
||||||
text(f"SET SESSION MAX_EXECUTION_TIME = {mysql_timeout}")
|
text(f"SET SESSION MAX_EXECUTION_TIME = {mysql_timeout}")
|
||||||
)
|
)
|
||||||
return _execute_query(session, sql)
|
return _execute_query(session, sql, params)
|
||||||
|
|
||||||
elif self.dialect == "postgresql":
|
elif self.dialect == "postgresql":
|
||||||
# PostgreSQL: Set statement_timeout in milliseconds
|
# PostgreSQL: Set statement_timeout in milliseconds
|
||||||
session.execute(
|
session.execute(
|
||||||
text(f"SET statement_timeout = {int(timeout * 1000)}")
|
text(f"SET statement_timeout = {int(timeout * 1000)}")
|
||||||
)
|
)
|
||||||
return _execute_query(session, sql)
|
return _execute_query(session, sql, params)
|
||||||
|
|
||||||
elif self.dialect == "oceanbase":
|
elif self.dialect == "oceanbase":
|
||||||
# OceanBase: Set ob_query_timeout in microseconds
|
# OceanBase: Set ob_query_timeout in microseconds
|
||||||
@@ -541,17 +554,19 @@ class RDBMSConnector(BaseConnector):
|
|||||||
session.execute(
|
session.execute(
|
||||||
text(f"SET SESSION ob_query_timeout = {ob_timeout}")
|
text(f"SET SESSION ob_query_timeout = {ob_timeout}")
|
||||||
)
|
)
|
||||||
return _execute_query(session, sql)
|
return _execute_query(session, sql, params)
|
||||||
|
|
||||||
elif self.dialect == "mssql":
|
elif self.dialect == "mssql":
|
||||||
# MSSQL: Use execution_options if supported by driver
|
# MSSQL: Use execution_options if supported by driver
|
||||||
sql_with_timeout = sql.execution_options(timeout=int(timeout))
|
sql_with_timeout = sql.execution_options(timeout=int(timeout))
|
||||||
return _execute_query(session, sql_with_timeout)
|
return _execute_query(session, sql_with_timeout, params)
|
||||||
|
|
||||||
elif self.dialect == "duckdb":
|
elif self.dialect == "duckdb":
|
||||||
# DuckDB: Use ThreadPoolExecutor for timeout
|
# DuckDB: Use ThreadPoolExecutor for timeout
|
||||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||||
future = executor.submit(_execute_query, session, sql)
|
future = executor.submit(
|
||||||
|
_execute_query, session, sql, params
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
return future.result(timeout=timeout)
|
return future.result(timeout=timeout)
|
||||||
except FutureTimeoutError:
|
except FutureTimeoutError:
|
||||||
@@ -564,10 +579,10 @@ class RDBMSConnector(BaseConnector):
|
|||||||
f"Timeout not supported for dialect: {self.dialect}, "
|
f"Timeout not supported for dialect: {self.dialect}, "
|
||||||
"proceeding without timeout"
|
"proceeding without timeout"
|
||||||
)
|
)
|
||||||
return _execute_query(session, sql)
|
return _execute_query(session, sql, params)
|
||||||
|
|
||||||
# No timeout specified, execute normally
|
# No timeout specified, execute normally
|
||||||
return _execute_query(session, sql)
|
return _execute_query(session, sql, params)
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
if "timeout" in str(e).lower() or "timed out" in str(e).lower():
|
if "timeout" in str(e).lower() or "timed out" in str(e).lower():
|
||||||
|
Reference in New Issue
Block a user