From 4cdbb22cb655276d3723b23c6410d8e3f0bf9421 Mon Sep 17 00:00:00 2001 From: geebytes Date: Mon, 23 Jun 2025 14:06:00 +0800 Subject: [PATCH] fix(core):Remove SQL comments before formatting the SQL (#2769) --- .../chat_excel/excel_analyze/chat.py | 1 - .../chat_excel/excel_analyze/prompt.py | 9 +- .../src/dbgpt/agent/util/api_call.py | 10 +- .../dbgpt-core/src/dbgpt/util/__init__.py | 2 + .../dbgpt-core/src/dbgpt/util/sql_utils.py | 10 + .../src/dbgpt/util/tests/test_sql_utils.py | 200 ++++++++++++++++++ .../storage/vector_store/elastic_store.py | 10 +- 7 files changed, 233 insertions(+), 9 deletions(-) create mode 100644 packages/dbgpt-core/src/dbgpt/util/sql_utils.py create mode 100644 packages/dbgpt-core/src/dbgpt/util/tests/test_sql_utils.py diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/chat.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/chat.py index 263a6e917..71e9bfa47 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/chat.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/chat.py @@ -201,5 +201,4 @@ class ChatExcel(BaseChat): view_msg = self.stream_plugin_call(text_msg) view_msg = final_output.gen_text_with_thinking(new_text=view_msg) view_msg = view_msg.replace("\n", "\\n") - return final_output.text, view_msg diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/prompt.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/prompt.py index e2529db24..7186cdf76 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/prompt.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/prompt.py @@ -49,7 +49,10 @@ Constraints: generated SQL and do not use column names that are not in the data structure 4. Prioritize using data analysis methods to answer. If the user's question does \ not involve data analysis content, you can answer based on your understanding - 5. Convert the SQL part in the output content to: \ + 5. DuckDB processes timestamps using dedicated functions (like to_timestamp()) \ + instead of direct CAST + 6. Please note that comment lines should be on a separate line and not on the same + 7. Convert the SQL part in the output content to: \ [display method]\ [correct duckdb data analysis sql] \ format, refer to the return format requirements @@ -138,7 +141,9 @@ DuckDB SQL数据分析回答用户的问题。 3.SQL中需要使用的表名是: {table_name},请检查你生成的sql,\ 不要使用没在数据结构中的列名 4.优先使用数据分析的方式回答,如果用户问题不涉及数据分析内容,你可以按你的理解进行回答 - 5.输出内容中sql部分转换为: + 5.DuckDB 处理时间戳需通过专用函数(如 to_timestamp())而非直接 CAST + 6.请注意,注释行要单独一行,不要放在 SQL 语句的同一行中 + 7.输出内容中sql部分转换为: [数据显示方式]\ [正确的duckdb数据分析sql] \ 这样的格式,参考返回格式要求 diff --git a/packages/dbgpt-core/src/dbgpt/agent/util/api_call.py b/packages/dbgpt-core/src/dbgpt/agent/util/api_call.py index d7e537745..9cd284b26 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/util/api_call.py +++ b/packages/dbgpt-core/src/dbgpt/agent/util/api_call.py @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Union from dbgpt._private.pydantic import BaseModel from dbgpt.agent.core.schema import Status from dbgpt.util.json_utils import serialize +from dbgpt.util.sql_utils import remove_sql_comments from dbgpt.util.string_utils import extract_content, extract_content_open_ending logger = logging.getLogger(__name__) @@ -105,10 +106,17 @@ class ApiCall: all_context = all_context.replace(tag + api_context, api_context) return all_context + def _remove_sql_comments(self, sql: str) -> str: + """Remove SQL comments from the given SQL string.""" + return remove_sql_comments(sql) + def _format_api_context(self, raw_api_context: str) -> str: """Format the API context.""" # Remove newline characters - + if "" in raw_api_context: + # For cases involving SQL, remove comments first—otherwise, + # removing line breaks afterward may cause SQL statement errors. + raw_api_context = self._remove_sql_comments(raw_api_context) raw_api_context = ( raw_api_context.replace("\\n", " ") .replace("\n", " ") diff --git a/packages/dbgpt-core/src/dbgpt/util/__init__.py b/packages/dbgpt-core/src/dbgpt/util/__init__.py index bac9a7f8f..4ff38b5de 100644 --- a/packages/dbgpt-core/src/dbgpt/util/__init__.py +++ b/packages/dbgpt-core/src/dbgpt/util/__init__.py @@ -6,6 +6,7 @@ from .parameter_utils import ( # noqa: F401 EnvArgumentParser, ParameterDescription, ) +from .sql_utils import remove_sql_comments # noqa: F401 from .utils import get_gpu_memory, get_or_create_event_loop # noqa: F401 __ALL__ = [ @@ -17,4 +18,5 @@ __ALL__ = [ "EnvArgumentParser", "AppConfig", "RegisterParameters", + "remove_sql_comments", ] diff --git a/packages/dbgpt-core/src/dbgpt/util/sql_utils.py b/packages/dbgpt-core/src/dbgpt/util/sql_utils.py new file mode 100644 index 000000000..770d5ee7b --- /dev/null +++ b/packages/dbgpt-core/src/dbgpt/util/sql_utils.py @@ -0,0 +1,10 @@ +import re + + +def remove_sql_comments(sql: str) -> str: + """Remove SQL comments from the given SQL string.""" + + # Remove single-line comments (--) and multi-line comments (/* ... */) + sql = re.sub(r"--.*?(\n|$)", "", sql) + sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL) + return sql diff --git a/packages/dbgpt-core/src/dbgpt/util/tests/test_sql_utils.py b/packages/dbgpt-core/src/dbgpt/util/tests/test_sql_utils.py new file mode 100644 index 000000000..ca646ee57 --- /dev/null +++ b/packages/dbgpt-core/src/dbgpt/util/tests/test_sql_utils.py @@ -0,0 +1,200 @@ +import pytest + +from dbgpt.util.sql_utils import remove_sql_comments + + +class TestRemoveSqlComments: + """Test cases for the remove_sql_comments method.""" + + def test_remove_single_line_comments_basic(self): + """Test removing basic single-line comments.""" + sql = "SELECT * FROM users -- This is a comment" + expected = "SELECT * FROM users " + result = remove_sql_comments(sql) + assert result == expected + + def test_remove_single_line_comments_with_newline(self): + """Test removing single-line comments followed by newline.""" + sql = "SELECT * FROM users -- This is a comment\nWHERE id = 1" + expected = "SELECT * FROM users WHERE id = 1" + result = remove_sql_comments(sql) + assert result == expected + + def test_remove_single_line_comments_at_end(self): + """Test removing single-line comments at the end of string.""" + sql = "SELECT * FROM users\n-- Final comment" + expected = "SELECT * FROM users\n" + result = remove_sql_comments(sql) + assert result == expected + + def test_remove_multiple_single_line_comments(self): + """Test removing multiple single-line comments.""" + sql = """SELECT * FROM users -- Comment 1 + WHERE id = 1 -- Comment 2 + ORDER BY name -- Comment 3""" + expected = """SELECT * FROM users WHERE id = 1 ORDER BY name """ # noqa: E501 + result = remove_sql_comments(sql) + assert result == expected + + def test_remove_multiline_comments_basic(self): + """Test removing basic multi-line comments.""" + sql = "SELECT * FROM users /* This is a comment */ WHERE id = 1" + expected = "SELECT * FROM users WHERE id = 1" + result = remove_sql_comments(sql) + assert result == expected + + def test_remove_multiline_comments_with_newlines(self): + """Test removing multi-line comments spanning multiple lines.""" + sql = """SELECT * FROM users + /* This is a + multi-line comment */ + WHERE id = 1""" + expected = """SELECT * FROM users + + WHERE id = 1""" + result = remove_sql_comments(sql) + assert result == expected + + def test_remove_multiple_multiline_comments(self): + """Test removing multiple multi-line comments.""" + sql = "SELECT * /* comment 1 */ FROM users /* comment 2 */ WHERE id = 1" + expected = "SELECT * FROM users WHERE id = 1" + result = remove_sql_comments(sql) + assert result == expected + + def test_remove_mixed_comments(self): + """Test removing both single-line and multi-line comments.""" + sql = """SELECT * FROM users /* multi-line comment */ + WHERE id = 1 -- single line comment + /* another multi-line + comment */ + ORDER BY name -- final comment""" + expected = """SELECT * FROM users + WHERE id = 1 + ORDER BY name """ + result = remove_sql_comments(sql) + assert result == expected + + def test_no_comments_to_remove(self): + """Test SQL with no comments.""" + sql = "SELECT * FROM users WHERE id = 1 ORDER BY name" + expected = "SELECT * FROM users WHERE id = 1 ORDER BY name" + result = remove_sql_comments(sql) + assert result == expected + + def test_empty_string(self): + """Test with empty string.""" + sql = "" + expected = "" + result = remove_sql_comments(sql) + assert result == expected + + def test_only_comments(self): + """Test string with only comments.""" + sql = "-- This is only a comment" + expected = "" + result = remove_sql_comments(sql) + assert result == expected + + def test_only_multiline_comments(self): + """Test string with only multi-line comments.""" + sql = "/* This is only a multi-line comment */" + expected = "" + result = remove_sql_comments(sql) + assert result == expected + + def test_comments_with_special_characters(self): + """Test comments containing special characters.""" + sql = "SELECT * FROM users -- Comment with special chars: @#$%^&*()" + expected = "SELECT * FROM users " + result = remove_sql_comments(sql) + assert result == expected + + def test_comments_with_sql_keywords(self): + """Test comments containing SQL keywords.""" + sql = "SELECT * FROM users -- SELECT * FROM another_table" + expected = "SELECT * FROM users " + result = remove_sql_comments(sql) + assert result == expected + + def test_whitespace_handling(self): + """Test whitespace handling around comments.""" + sql = "SELECT * -- comment with spaces \n FROM users" + expected = "SELECT * FROM users" + result = remove_sql_comments(sql) + assert result == expected + + def test_multiline_comment_at_start(self): + """Test multi-line comment at the beginning.""" + sql = "/* Initial comment */ SELECT * FROM users" + expected = " SELECT * FROM users" + result = remove_sql_comments(sql) + assert result == expected + + def test_multiline_comment_at_end(self): + """Test multi-line comment at the end.""" + sql = "SELECT * FROM users /* Final comment */" + expected = "SELECT * FROM users " + result = remove_sql_comments(sql) + assert result == expected + + def test_complex_sql_with_comments(self): + """Test complex SQL statement with various comment types.""" + sql = """ + /* Query to get user information */ + SELECT + u.id, -- User ID + u.name, /* User full name */ + u.email -- Contact email + FROM users u + /* Join with posts table */ + LEFT JOIN posts p ON u.id = p.user_id + WHERE u.active = 1 -- Only active users + /* Order by creation date */ + ORDER BY u.created_at DESC + -- Limit results + LIMIT 10; + """ + + result = remove_sql_comments(sql) + + # Verify that all comments are removed + assert "--" not in result + assert "/*" not in result + assert "*/" not in result + + # Verify that SQL keywords are preserved + assert "SELECT" in result + assert "FROM" in result + assert "WHERE" in result + assert "ORDER BY" in result + assert "LIMIT" in result + + +# 如果需要测试实际的类实例,可以使用以下参数化测试 +@pytest.mark.parametrize( + "sql_input,expected_output", + [ + ("SELECT * FROM users", "SELECT * FROM users"), + ("SELECT * FROM users -- comment", "SELECT * FROM users "), + ("SELECT * /* comment */ FROM users", "SELECT * FROM users"), + ("", ""), + ("-- only comment", ""), + ("/* only comment */", ""), + ], +) +def test_remove_sql_comments_parametrized(sql_input, expected_output): + """Parametrized test for common cases.""" + + # 创建测试实例 + class MockClass: + def _remove_sql_comments(self, sql: str) -> str: + import re + + sql = re.sub(r"--.*?(\n|$)", "", sql) + sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL) + return sql + + instance = MockClass() + result = instance._remove_sql_comments(sql_input) + assert result == expected_output diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/elastic_store.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/elastic_store.py index f0e3ebc22..5be306e85 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/elastic_store.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/elastic_store.py @@ -173,14 +173,14 @@ class ElasticStore(VectorStoreBase): connect_kwargs = {} elasticsearch_vector_config = vector_store_config.to_dict() self.uri = os.getenv( - "ELASTICSEARCH_URL", "localhost" - ) or elasticsearch_vector_config.get("uri") + "ELASTICSEARCH_URL", None + ) or elasticsearch_vector_config.get("uri", "localhost") self.port = os.getenv( - "ELASTICSEARCH_PORT", "9200" - ) or elasticsearch_vector_config.get("post") + "ELASTICSEARCH_PORT", None + ) or elasticsearch_vector_config.get("post", "9200") self.username = os.getenv( "ELASTICSEARCH_USERNAME" - ) or elasticsearch_vector_config.get("username") + ) or elasticsearch_vector_config.get("user") self.password = os.getenv( "ELASTICSEARCH_PASSWORD" ) or elasticsearch_vector_config.get("password")