From 129509fcd24ea86fcb12b1fd242f11ec417d853d Mon Sep 17 00:00:00 2001 From: dobet <247479943@qq.com> Date: Thu, 17 Apr 2025 12:44:51 +0800 Subject: [PATCH] fix connect mssql embedding error (#2589) --- .../dbgpt_ext/datasource/rdbms/conn_mssql.py | 213 ++++++++++++++++++ 1 file changed, 213 insertions(+) diff --git a/packages/dbgpt-ext/src/dbgpt_ext/datasource/rdbms/conn_mssql.py b/packages/dbgpt-ext/src/dbgpt_ext/datasource/rdbms/conn_mssql.py index ef311afac..0cd03f5e0 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/datasource/rdbms/conn_mssql.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/datasource/rdbms/conn_mssql.py @@ -77,3 +77,216 @@ class MSSQLConnector(RDBMSConnector): table_colums.append(field_info[0]) results.append(f"{table_name}({','.join(table_colums)});") return results + + def get_users(self): + with self.session_scope() as session: + cursor = session.execute( + text( + "SELECT name FROM sys.server_principals " + "WHERE type_desc = 'SQL_LOGIN'" + ) + ) + return [row[0] for row in cursor.fetchall()] + + def get_grants(self): + with self.session_scope() as session: + query = """ + SELECT + CASE WHEN perm.state <> 'W' THEN perm.state_desc ELSE 'GRANT WITH + GRANT OPTION' END AS [Permission], + perm.permission_name AS [Permission Name], + CASE + WHEN perm.class = 0 THEN 'SERVER' + WHEN perm.class = 1 THEN OBJECT_NAME(perm.major_id) + WHEN perm.class = 3 THEN SCHEMA_NAME(perm.major_id) + ELSE CAST(perm.class AS VARCHAR) + END AS [Securable], + princ.name AS [Principal] + FROM + sys.server_permissions perm + JOIN sys.server_principals princ ON perm.grantee_principal_id = + princ.principal_id + """ + cursor = session.execute(text(query)) + return cursor.fetchall() + + def _decode_if_bytes(self, value): + if isinstance(value, bytes): + return value.decode("utf-8") + return value + + def get_charset(self): + with self.session_scope() as session: + query = ( + "SELECT DATABASEPROPERTYEX(DB_NAME(), 'Collation') AS DatabaseCollation" + ) + cursor = session.execute(text(query)) + result = cursor.fetchone() + + if result and result[0]: + collation = self._decode_if_bytes(result[0]) + parts = collation.split("_") + if len(parts) >= 2: + return parts[1] + return collation + + return "SQL_Server_Default" + + def get_collation(self): + with self.session_scope() as session: + cursor = session.execute( + text("SELECT SERVERPROPERTY('Collation') AS DatabaseCollation") + ) + collation = cursor.fetchone()[0] + return collation + + def get_table_names(self): + tables = [] + + with self.session_scope() as session: + query = """ + SELECT + TABLE_SCHEMA + '.' + TABLE_NAME AS full_table_name + FROM + INFORMATION_SCHEMA.TABLES + WHERE + TABLE_TYPE = 'BASE TABLE' + AND TABLE_CATALOG = DB_NAME() + """ + + cursor = session.execute(text(query)) + tables = [row[0] for row in cursor.fetchall()] + + if not tables: + query = """ + SELECT + SCHEMA_NAME(schema_id) + '.' + name AS full_table_name + FROM + sys.tables + """ + cursor = session.execute(text(query)) + tables = [row[0] for row in cursor.fetchall()] + + if not tables: + query = """ + SELECT + name AS table_name + FROM + sys.tables + """ + cursor = session.execute(text(query)) + tables = [row[0] for row in cursor.fetchall()] + + return tables + + def get_columns(self, table_name: str): + if "." in table_name: + schema_name, pure_table_name = table_name.split(".", 1) + else: + schema_name = "dbo" + pure_table_name = table_name + + with self.session_scope() as session: + query = """ + SELECT + COLUMN_NAME AS name, + DATA_TYPE AS type, + CASE WHEN IS_NULLABLE = 'YES' THEN 1 ELSE 0 END AS nullable, + COLUMN_DEFAULT AS default_value, + CHARACTER_MAXIMUM_LENGTH AS max_length + FROM + INFORMATION_SCHEMA.COLUMNS + WHERE + TABLE_SCHEMA = :schema + AND TABLE_NAME = :table + ORDER BY + ORDINAL_POSITION + """ + cursor = session.execute( + text(query), {"schema": schema_name, "table": pure_table_name} + ) + results = cursor.fetchall() + + columns = [] + for row in results: + name = row[0].decode("utf-8") if isinstance(row[0], bytes) else row[0] + col_type = ( + row[1].decode("utf-8") if isinstance(row[1], bytes) else row[1] + ) + column = { + "name": name, + "type": col_type, + "nullable": bool(row[2]), + } + + if row[3] is not None: + default = ( + row[3].decode("utf-8") if isinstance(row[3], bytes) else row[3] + ) + column["default"] = default + + if row[4] is not None: + column["max_length"] = row[4] + columns.append(column) + + return columns + + def get_indexes(self, table_name: str): + if "." in table_name: + schema_name, pure_table_name = table_name.split(".", 1) + else: + schema_name = "dbo" + pure_table_name = table_name + + with self.session_scope() as session: + query = """ + SELECT + i.name AS index_name, + c.name AS column_name, + i.is_unique AS is_unique, + i.is_primary_key AS is_primary_key + FROM + sys.indexes i + INNER JOIN + sys.index_columns ic ON i.object_id = ic.object_id AND i.index_id + = ic.index_id + INNER JOIN + sys.columns c ON ic.object_id = c.object_id AND ic.column_id + = c.column_id + INNER JOIN + sys.tables t ON i.object_id = t.object_id + INNER JOIN + sys.schemas s ON t.schema_id = s.schema_id + WHERE + t.name = :table + AND s.name = :schema + AND i.name IS NOT NULL + ORDER BY + i.name, ic.key_ordinal + """ + cursor = session.execute( + text(query), {"schema": schema_name, "table": pure_table_name} + ) + results = cursor.fetchall() + + index_dict = {} + for row in results: + index_name = ( + row[0].decode("utf-8") if isinstance(row[0], bytes) else row[0] + ) + column_name = ( + row[1].decode("utf-8") if isinstance(row[1], bytes) else row[1] + ) + is_unique = bool(row[2]) + is_primary_key = bool(row[3]) + if index_name not in index_dict: + index_dict[index_name] = { + "name": index_name, + "column_names": [], + "unique": is_unique, + "primary": is_primary_key, + } + + index_dict[index_name]["column_names"].append(column_name) + + return list(index_dict.values())