fix(core):Remove SQL comments before formatting the SQL (#2769)

This commit is contained in:
geebytes 2025-06-23 14:06:00 +08:00 committed by GitHub
parent c56f8810e6
commit 4cdbb22cb6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 233 additions and 9 deletions

View File

@ -201,5 +201,4 @@ class ChatExcel(BaseChat):
view_msg = self.stream_plugin_call(text_msg)
view_msg = final_output.gen_text_with_thinking(new_text=view_msg)
view_msg = view_msg.replace("\n", "\\n")
return final_output.text, view_msg

View File

@ -49,7 +49,10 @@ Constraints:
generated SQL and do not use column names that are not in the data structure
4. Prioritize using data analysis methods to answer. If the user's question does \
not involve data analysis content, you can answer based on your understanding
5. Convert the SQL part in the output content to: \
5. DuckDB processes timestamps using dedicated functions (like to_timestamp()) \
instead of direct CAST
6. Please note that comment lines should be on a separate line and not on the same
7. Convert the SQL part in the output content to: \
<api-call><name>[display method]</name><args><sql>\
[correct duckdb data analysis sql]</sql></args></api-call> \
format, refer to the return format requirements
@ -138,7 +141,9 @@ DuckDB SQL数据分析回答用户的问题。
3.SQL中需要使用的表名是: {table_name},请检查你生成的sql\
不要使用没在数据结构中的列名
4.优先使用数据分析的方式回答如果用户问题不涉及数据分析内容你可以按你的理解进行回答
5.输出内容中sql部分转换为
5.DuckDB 处理时间戳需通过专用函数 to_timestamp()而非直接 CAST
6.请注意注释行要单独一行不要放在 SQL 语句的同一行中
7.输出内容中sql部分转换为
<api-call><name>[数据显示方式]</name><args><sql>\
[正确的duckdb数据分析sql]</sql></args></api-call> \
这样的格式参考返回格式要求

View File

@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Union
from dbgpt._private.pydantic import BaseModel
from dbgpt.agent.core.schema import Status
from dbgpt.util.json_utils import serialize
from dbgpt.util.sql_utils import remove_sql_comments
from dbgpt.util.string_utils import extract_content, extract_content_open_ending
logger = logging.getLogger(__name__)
@ -105,10 +106,17 @@ class ApiCall:
all_context = all_context.replace(tag + api_context, api_context)
return all_context
def _remove_sql_comments(self, sql: str) -> str:
    """Return ``sql`` with comments stripped by ``remove_sql_comments``."""
    stripped = remove_sql_comments(sql)
    return stripped
def _format_api_context(self, raw_api_context: str) -> str:
"""Format the API context."""
# Remove newline characters
if "<sql>" in raw_api_context:
# For cases involving SQL, remove comments first—otherwise,
# removing line breaks afterward may cause SQL statement errors.
raw_api_context = self._remove_sql_comments(raw_api_context)
raw_api_context = (
raw_api_context.replace("\\n", " ")
.replace("\n", " ")

View File

@ -6,6 +6,7 @@ from .parameter_utils import ( # noqa: F401
EnvArgumentParser,
ParameterDescription,
)
from .sql_utils import remove_sql_comments # noqa: F401
from .utils import get_gpu_memory, get_or_create_event_loop # noqa: F401
__ALL__ = [
@ -17,4 +18,5 @@ __ALL__ = [
"EnvArgumentParser",
"AppConfig",
"RegisterParameters",
"remove_sql_comments",
]

View File

@ -0,0 +1,10 @@
import re
def remove_sql_comments(sql: str) -> str:
    """Remove SQL comments from the given SQL string.

    Both single-line comments (``-- ...``) and block comments (``/* ... */``)
    are deleted. Quoted text (single-quoted string literals and double-quoted
    identifiers) is copied through unchanged, so comment markers that appear
    inside a literal such as ``'a--b'`` are preserved.

    Args:
        sql: Text that may contain SQL comments.

    Returns:
        The text with comments removed. A single-line comment is removed
        together with the newline that terminates it; no other whitespace is
        added or collapsed. An unterminated ``/*`` is left in place.
    """
    pattern = (
        # Group 1: quoted text; a doubled quote ('' or "") is an escaped quote.
        r"('(?:[^']|'')*'"
        r'|"(?:[^"]|"")*")'
        # Single-line comment, including its terminating newline if present.
        r"|--[^\n]*\n?"
        # Block comment; DOTALL lets it span several lines.
        r"|/\*.*?\*/"
    )
    # A single left-to-right pass: whichever construct starts first wins, so a
    # "--" inside a string or block comment is never mistaken for a comment.
    # Quoted text (group 1) is re-emitted verbatim; comments become "".
    return re.sub(pattern, lambda m: m.group(1) or "", sql, flags=re.DOTALL)

View File

@ -0,0 +1,200 @@
import pytest
from dbgpt.util.sql_utils import remove_sql_comments
class TestRemoveSqlComments:
    """Test cases for the ``remove_sql_comments`` function.

    ``remove_sql_comments`` deletes comment text only: a ``--`` comment is
    removed together with its terminating newline, a ``/* */`` comment is
    removed on its own, and surrounding whitespace is never collapsed. The
    expected values below therefore keep every space and newline that
    surrounded a removed comment.
    """

    def test_remove_single_line_comments_basic(self):
        """Test removing basic single-line comments."""
        sql = "SELECT * FROM users -- This is a comment"
        expected = "SELECT * FROM users "
        result = remove_sql_comments(sql)
        assert result == expected

    def test_remove_single_line_comments_with_newline(self):
        """Test removing single-line comments followed by newline."""
        sql = "SELECT * FROM users -- This is a comment\nWHERE id = 1"
        # The newline ending the comment is removed along with it.
        expected = "SELECT * FROM users WHERE id = 1"
        result = remove_sql_comments(sql)
        assert result == expected

    def test_remove_single_line_comments_at_end(self):
        """Test removing single-line comments at the end of string."""
        sql = "SELECT * FROM users\n-- Final comment"
        expected = "SELECT * FROM users\n"
        result = remove_sql_comments(sql)
        assert result == expected

    def test_remove_multiple_single_line_comments(self):
        """Test removing multiple single-line comments."""
        sql = (
            "SELECT * FROM users -- Comment 1\n"
            "WHERE id = 1 -- Comment 2\n"
            "ORDER BY name -- Comment 3"
        )
        expected = "SELECT * FROM users WHERE id = 1 ORDER BY name "
        result = remove_sql_comments(sql)
        assert result == expected

    def test_remove_multiline_comments_basic(self):
        """Test removing basic multi-line comments."""
        sql = "SELECT * FROM users /* This is a comment */ WHERE id = 1"
        # The spaces on both sides of the comment remain.
        expected = "SELECT * FROM users  WHERE id = 1"
        result = remove_sql_comments(sql)
        assert result == expected

    def test_remove_multiline_comments_with_newlines(self):
        """Test removing multi-line comments spanning multiple lines."""
        sql = (
            "SELECT * FROM users\n"
            "/* This is a\n"
            "multi-line comment */\n"
            "WHERE id = 1"
        )
        # The newlines before and after the comment both remain.
        expected = "SELECT * FROM users\n\nWHERE id = 1"
        result = remove_sql_comments(sql)
        assert result == expected

    def test_remove_multiple_multiline_comments(self):
        """Test removing multiple multi-line comments."""
        sql = "SELECT * /* comment 1 */ FROM users /* comment 2 */ WHERE id = 1"
        expected = "SELECT *  FROM users  WHERE id = 1"
        result = remove_sql_comments(sql)
        assert result == expected

    def test_remove_mixed_comments(self):
        """Test removing both single-line and multi-line comments."""
        sql = (
            "SELECT * FROM users /* multi-line comment */\n"
            "WHERE id = 1 -- single line comment\n"
            "/* another multi-line\n"
            "comment */\n"
            "ORDER BY name -- final comment"
        )
        # Each line keeps the space that preceded its removed comment.
        expected = "SELECT * FROM users \nWHERE id = 1 \nORDER BY name "
        result = remove_sql_comments(sql)
        assert result == expected

    def test_no_comments_to_remove(self):
        """Test SQL with no comments."""
        sql = "SELECT * FROM users WHERE id = 1 ORDER BY name"
        expected = "SELECT * FROM users WHERE id = 1 ORDER BY name"
        result = remove_sql_comments(sql)
        assert result == expected

    def test_empty_string(self):
        """Test with empty string."""
        sql = ""
        expected = ""
        result = remove_sql_comments(sql)
        assert result == expected

    def test_only_comments(self):
        """Test string with only comments."""
        sql = "-- This is only a comment"
        expected = ""
        result = remove_sql_comments(sql)
        assert result == expected

    def test_only_multiline_comments(self):
        """Test string with only multi-line comments."""
        sql = "/* This is only a multi-line comment */"
        expected = ""
        result = remove_sql_comments(sql)
        assert result == expected

    def test_comments_with_special_characters(self):
        """Test comments containing special characters."""
        sql = "SELECT * FROM users -- Comment with special chars: @#$%^&*()"
        expected = "SELECT * FROM users "
        result = remove_sql_comments(sql)
        assert result == expected

    def test_comments_with_sql_keywords(self):
        """Test comments containing SQL keywords."""
        sql = "SELECT * FROM users -- SELECT * FROM another_table"
        expected = "SELECT * FROM users "
        result = remove_sql_comments(sql)
        assert result == expected

    def test_whitespace_handling(self):
        """Test whitespace handling around comments."""
        sql = "SELECT * -- comment with spaces \n FROM users"
        # One space precedes the comment and one follows the removed newline.
        expected = "SELECT *  FROM users"
        result = remove_sql_comments(sql)
        assert result == expected

    def test_multiline_comment_at_start(self):
        """Test multi-line comment at the beginning."""
        sql = "/* Initial comment */ SELECT * FROM users"
        expected = " SELECT * FROM users"
        result = remove_sql_comments(sql)
        assert result == expected

    def test_multiline_comment_at_end(self):
        """Test multi-line comment at the end."""
        sql = "SELECT * FROM users /* Final comment */"
        expected = "SELECT * FROM users "
        result = remove_sql_comments(sql)
        assert result == expected

    def test_complex_sql_with_comments(self):
        """Test complex SQL statement with various comment types."""
        sql = """
        /* Query to get user information */
        SELECT
            u.id, -- User ID
            u.name, /* User full name */
            u.email -- Contact email
        FROM users u
        /* Join with posts table */
        LEFT JOIN posts p ON u.id = p.user_id
        WHERE u.active = 1 -- Only active users
        /* Order by creation date */
        ORDER BY u.created_at DESC
        -- Limit results
        LIMIT 10;
        """
        result = remove_sql_comments(sql)
        # Verify that all comments are removed
        assert "--" not in result
        assert "/*" not in result
        assert "*/" not in result
        # Verify that SQL keywords are preserved
        assert "SELECT" in result
        assert "FROM" in result
        assert "WHERE" in result
        assert "ORDER BY" in result
        assert "LIMIT" in result
# Parametrized smoke test covering the most common inputs.
@pytest.mark.parametrize(
    "sql_input,expected_output",
    [
        ("SELECT * FROM users", "SELECT * FROM users"),
        ("SELECT * FROM users -- comment", "SELECT * FROM users "),
        # Removing the comment leaves the spaces on both sides of it.
        ("SELECT * /* comment */ FROM users", "SELECT *  FROM users"),
        ("", ""),
        ("-- only comment", ""),
        ("/* only comment */", ""),
    ],
)
def test_remove_sql_comments_parametrized(sql_input, expected_output):
    """Check common inputs against the real ``remove_sql_comments`` helper."""
    # Exercise the imported function rather than a local copy of its regexes,
    # so this test fails if the production implementation regresses.
    result = remove_sql_comments(sql_input)
    assert result == expected_output

View File

@ -173,14 +173,14 @@ class ElasticStore(VectorStoreBase):
connect_kwargs = {}
elasticsearch_vector_config = vector_store_config.to_dict()
self.uri = os.getenv(
"ELASTICSEARCH_URL", "localhost"
) or elasticsearch_vector_config.get("uri")
"ELASTICSEARCH_URL", None
) or elasticsearch_vector_config.get("uri", "localhost")
self.port = os.getenv(
"ELASTICSEARCH_PORT", "9200"
) or elasticsearch_vector_config.get("post")
"ELASTICSEARCH_PORT", None
) or elasticsearch_vector_config.get("post", "9200")
self.username = os.getenv(
"ELASTICSEARCH_USERNAME"
) or elasticsearch_vector_config.get("username")
) or elasticsearch_vector_config.get("user")
self.password = os.getenv(
"ELASTICSEARCH_PASSWORD"
) or elasticsearch_vector_config.get("password")