mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
fix(core):Remove SQL comments before formatting the SQL (#2769)
This commit is contained in:
parent
c56f8810e6
commit
4cdbb22cb6
@ -201,5 +201,4 @@ class ChatExcel(BaseChat):
|
||||
view_msg = self.stream_plugin_call(text_msg)
|
||||
view_msg = final_output.gen_text_with_thinking(new_text=view_msg)
|
||||
view_msg = view_msg.replace("\n", "\\n")
|
||||
|
||||
return final_output.text, view_msg
|
||||
|
@ -49,7 +49,10 @@ Constraints:
|
||||
generated SQL and do not use column names that are not in the data structure
|
||||
4. Prioritize using data analysis methods to answer. If the user's question does \
|
||||
not involve data analysis content, you can answer based on your understanding
|
||||
5. Convert the SQL part in the output content to: \
|
||||
5. DuckDB processes timestamps using dedicated functions (like to_timestamp()) \
|
||||
instead of direct CAST
|
||||
6. Please note that comment lines should be on a separate line and not on the same
|
||||
7. Convert the SQL part in the output content to: \
|
||||
<api-call><name>[display method]</name><args><sql>\
|
||||
[correct duckdb data analysis sql]</sql></args></api-call> \
|
||||
format, refer to the return format requirements
|
||||
@ -138,7 +141,9 @@ DuckDB SQL数据分析回答用户的问题。
|
||||
3.SQL中需要使用的表名是: {table_name},请检查你生成的sql,\
|
||||
不要使用没在数据结构中的列名
|
||||
4.优先使用数据分析的方式回答,如果用户问题不涉及数据分析内容,你可以按你的理解进行回答
|
||||
5.输出内容中sql部分转换为:
|
||||
5.DuckDB 处理时间戳需通过专用函数(如 to_timestamp())而非直接 CAST
|
||||
6.请注意,注释行要单独一行,不要放在 SQL 语句的同一行中
|
||||
7.输出内容中sql部分转换为:
|
||||
<api-call><name>[数据显示方式]</name><args><sql>\
|
||||
[正确的duckdb数据分析sql]</sql></args></api-call> \
|
||||
这样的格式,参考返回格式要求
|
||||
|
@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt.agent.core.schema import Status
|
||||
from dbgpt.util.json_utils import serialize
|
||||
from dbgpt.util.sql_utils import remove_sql_comments
|
||||
from dbgpt.util.string_utils import extract_content, extract_content_open_ending
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -105,10 +106,17 @@ class ApiCall:
|
||||
all_context = all_context.replace(tag + api_context, api_context)
|
||||
return all_context
|
||||
|
||||
def _remove_sql_comments(self, sql: str) -> str:
|
||||
"""Remove SQL comments from the given SQL string."""
|
||||
return remove_sql_comments(sql)
|
||||
|
||||
def _format_api_context(self, raw_api_context: str) -> str:
|
||||
"""Format the API context."""
|
||||
# Remove newline characters
|
||||
|
||||
if "<sql>" in raw_api_context:
|
||||
# For cases involving SQL, remove comments first—otherwise,
|
||||
# removing line breaks afterward may cause SQL statement errors.
|
||||
raw_api_context = self._remove_sql_comments(raw_api_context)
|
||||
raw_api_context = (
|
||||
raw_api_context.replace("\\n", " ")
|
||||
.replace("\n", " ")
|
||||
|
@ -6,6 +6,7 @@ from .parameter_utils import ( # noqa: F401
|
||||
EnvArgumentParser,
|
||||
ParameterDescription,
|
||||
)
|
||||
from .sql_utils import remove_sql_comments # noqa: F401
|
||||
from .utils import get_gpu_memory, get_or_create_event_loop # noqa: F401
|
||||
|
||||
__ALL__ = [
|
||||
@ -17,4 +18,5 @@ __ALL__ = [
|
||||
"EnvArgumentParser",
|
||||
"AppConfig",
|
||||
"RegisterParameters",
|
||||
"remove_sql_comments",
|
||||
]
|
||||
|
10
packages/dbgpt-core/src/dbgpt/util/sql_utils.py
Normal file
10
packages/dbgpt-core/src/dbgpt/util/sql_utils.py
Normal file
@ -0,0 +1,10 @@
|
||||
import re
|
||||
|
||||
|
||||
def remove_sql_comments(sql: str) -> str:
|
||||
"""Remove SQL comments from the given SQL string."""
|
||||
|
||||
# Remove single-line comments (--) and multi-line comments (/* ... */)
|
||||
sql = re.sub(r"--.*?(\n|$)", "", sql)
|
||||
sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL)
|
||||
return sql
|
200
packages/dbgpt-core/src/dbgpt/util/tests/test_sql_utils.py
Normal file
200
packages/dbgpt-core/src/dbgpt/util/tests/test_sql_utils.py
Normal file
@ -0,0 +1,200 @@
|
||||
import pytest
|
||||
|
||||
from dbgpt.util.sql_utils import remove_sql_comments
|
||||
|
||||
|
||||
class TestRemoveSqlComments:
|
||||
"""Test cases for the remove_sql_comments method."""
|
||||
|
||||
def test_remove_single_line_comments_basic(self):
|
||||
"""Test removing basic single-line comments."""
|
||||
sql = "SELECT * FROM users -- This is a comment"
|
||||
expected = "SELECT * FROM users "
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_remove_single_line_comments_with_newline(self):
|
||||
"""Test removing single-line comments followed by newline."""
|
||||
sql = "SELECT * FROM users -- This is a comment\nWHERE id = 1"
|
||||
expected = "SELECT * FROM users WHERE id = 1"
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_remove_single_line_comments_at_end(self):
|
||||
"""Test removing single-line comments at the end of string."""
|
||||
sql = "SELECT * FROM users\n-- Final comment"
|
||||
expected = "SELECT * FROM users\n"
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_remove_multiple_single_line_comments(self):
|
||||
"""Test removing multiple single-line comments."""
|
||||
sql = """SELECT * FROM users -- Comment 1
|
||||
WHERE id = 1 -- Comment 2
|
||||
ORDER BY name -- Comment 3"""
|
||||
expected = """SELECT * FROM users WHERE id = 1 ORDER BY name """ # noqa: E501
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_remove_multiline_comments_basic(self):
|
||||
"""Test removing basic multi-line comments."""
|
||||
sql = "SELECT * FROM users /* This is a comment */ WHERE id = 1"
|
||||
expected = "SELECT * FROM users WHERE id = 1"
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_remove_multiline_comments_with_newlines(self):
|
||||
"""Test removing multi-line comments spanning multiple lines."""
|
||||
sql = """SELECT * FROM users
|
||||
/* This is a
|
||||
multi-line comment */
|
||||
WHERE id = 1"""
|
||||
expected = """SELECT * FROM users
|
||||
|
||||
WHERE id = 1"""
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_remove_multiple_multiline_comments(self):
|
||||
"""Test removing multiple multi-line comments."""
|
||||
sql = "SELECT * /* comment 1 */ FROM users /* comment 2 */ WHERE id = 1"
|
||||
expected = "SELECT * FROM users WHERE id = 1"
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_remove_mixed_comments(self):
|
||||
"""Test removing both single-line and multi-line comments."""
|
||||
sql = """SELECT * FROM users /* multi-line comment */
|
||||
WHERE id = 1 -- single line comment
|
||||
/* another multi-line
|
||||
comment */
|
||||
ORDER BY name -- final comment"""
|
||||
expected = """SELECT * FROM users
|
||||
WHERE id = 1
|
||||
ORDER BY name """
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_no_comments_to_remove(self):
|
||||
"""Test SQL with no comments."""
|
||||
sql = "SELECT * FROM users WHERE id = 1 ORDER BY name"
|
||||
expected = "SELECT * FROM users WHERE id = 1 ORDER BY name"
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_empty_string(self):
|
||||
"""Test with empty string."""
|
||||
sql = ""
|
||||
expected = ""
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_only_comments(self):
|
||||
"""Test string with only comments."""
|
||||
sql = "-- This is only a comment"
|
||||
expected = ""
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_only_multiline_comments(self):
|
||||
"""Test string with only multi-line comments."""
|
||||
sql = "/* This is only a multi-line comment */"
|
||||
expected = ""
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_comments_with_special_characters(self):
|
||||
"""Test comments containing special characters."""
|
||||
sql = "SELECT * FROM users -- Comment with special chars: @#$%^&*()"
|
||||
expected = "SELECT * FROM users "
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_comments_with_sql_keywords(self):
|
||||
"""Test comments containing SQL keywords."""
|
||||
sql = "SELECT * FROM users -- SELECT * FROM another_table"
|
||||
expected = "SELECT * FROM users "
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_whitespace_handling(self):
|
||||
"""Test whitespace handling around comments."""
|
||||
sql = "SELECT * -- comment with spaces \n FROM users"
|
||||
expected = "SELECT * FROM users"
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_multiline_comment_at_start(self):
|
||||
"""Test multi-line comment at the beginning."""
|
||||
sql = "/* Initial comment */ SELECT * FROM users"
|
||||
expected = " SELECT * FROM users"
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_multiline_comment_at_end(self):
|
||||
"""Test multi-line comment at the end."""
|
||||
sql = "SELECT * FROM users /* Final comment */"
|
||||
expected = "SELECT * FROM users "
|
||||
result = remove_sql_comments(sql)
|
||||
assert result == expected
|
||||
|
||||
def test_complex_sql_with_comments(self):
|
||||
"""Test complex SQL statement with various comment types."""
|
||||
sql = """
|
||||
/* Query to get user information */
|
||||
SELECT
|
||||
u.id, -- User ID
|
||||
u.name, /* User full name */
|
||||
u.email -- Contact email
|
||||
FROM users u
|
||||
/* Join with posts table */
|
||||
LEFT JOIN posts p ON u.id = p.user_id
|
||||
WHERE u.active = 1 -- Only active users
|
||||
/* Order by creation date */
|
||||
ORDER BY u.created_at DESC
|
||||
-- Limit results
|
||||
LIMIT 10;
|
||||
"""
|
||||
|
||||
result = remove_sql_comments(sql)
|
||||
|
||||
# Verify that all comments are removed
|
||||
assert "--" not in result
|
||||
assert "/*" not in result
|
||||
assert "*/" not in result
|
||||
|
||||
# Verify that SQL keywords are preserved
|
||||
assert "SELECT" in result
|
||||
assert "FROM" in result
|
||||
assert "WHERE" in result
|
||||
assert "ORDER BY" in result
|
||||
assert "LIMIT" in result
|
||||
|
||||
|
||||
# 如果需要测试实际的类实例,可以使用以下参数化测试
|
||||
@pytest.mark.parametrize(
|
||||
"sql_input,expected_output",
|
||||
[
|
||||
("SELECT * FROM users", "SELECT * FROM users"),
|
||||
("SELECT * FROM users -- comment", "SELECT * FROM users "),
|
||||
("SELECT * /* comment */ FROM users", "SELECT * FROM users"),
|
||||
("", ""),
|
||||
("-- only comment", ""),
|
||||
("/* only comment */", ""),
|
||||
],
|
||||
)
|
||||
def test_remove_sql_comments_parametrized(sql_input, expected_output):
|
||||
"""Parametrized test for common cases."""
|
||||
|
||||
# 创建测试实例
|
||||
class MockClass:
|
||||
def _remove_sql_comments(self, sql: str) -> str:
|
||||
import re
|
||||
|
||||
sql = re.sub(r"--.*?(\n|$)", "", sql)
|
||||
sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL)
|
||||
return sql
|
||||
|
||||
instance = MockClass()
|
||||
result = instance._remove_sql_comments(sql_input)
|
||||
assert result == expected_output
|
@ -173,14 +173,14 @@ class ElasticStore(VectorStoreBase):
|
||||
connect_kwargs = {}
|
||||
elasticsearch_vector_config = vector_store_config.to_dict()
|
||||
self.uri = os.getenv(
|
||||
"ELASTICSEARCH_URL", "localhost"
|
||||
) or elasticsearch_vector_config.get("uri")
|
||||
"ELASTICSEARCH_URL", None
|
||||
) or elasticsearch_vector_config.get("uri", "localhost")
|
||||
self.port = os.getenv(
|
||||
"ELASTICSEARCH_PORT", "9200"
|
||||
) or elasticsearch_vector_config.get("post")
|
||||
"ELASTICSEARCH_PORT", None
|
||||
) or elasticsearch_vector_config.get("post", "9200")
|
||||
self.username = os.getenv(
|
||||
"ELASTICSEARCH_USERNAME"
|
||||
) or elasticsearch_vector_config.get("username")
|
||||
) or elasticsearch_vector_config.get("user")
|
||||
self.password = os.getenv(
|
||||
"ELASTICSEARCH_PASSWORD"
|
||||
) or elasticsearch_vector_config.get("password")
|
||||
|
Loading…
Reference in New Issue
Block a user