mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-22 03:41:43 +00:00
fix connect mssql embedding error (#2589)
This commit is contained in:
parent
8d66d0271f
commit
129509fcd2
@ -77,3 +77,216 @@ class MSSQLConnector(RDBMSConnector):
|
||||
table_colums.append(field_info[0])
|
||||
results.append(f"{table_name}({','.join(table_colums)});")
|
||||
return results
|
||||
|
||||
def get_users(self):
|
||||
with self.session_scope() as session:
|
||||
cursor = session.execute(
|
||||
text(
|
||||
"SELECT name FROM sys.server_principals "
|
||||
"WHERE type_desc = 'SQL_LOGIN'"
|
||||
)
|
||||
)
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
|
||||
def get_grants(self):
|
||||
with self.session_scope() as session:
|
||||
query = """
|
||||
SELECT
|
||||
CASE WHEN perm.state <> 'W' THEN perm.state_desc ELSE 'GRANT WITH
|
||||
GRANT OPTION' END AS [Permission],
|
||||
perm.permission_name AS [Permission Name],
|
||||
CASE
|
||||
WHEN perm.class = 0 THEN 'SERVER'
|
||||
WHEN perm.class = 1 THEN OBJECT_NAME(perm.major_id)
|
||||
WHEN perm.class = 3 THEN SCHEMA_NAME(perm.major_id)
|
||||
ELSE CAST(perm.class AS VARCHAR)
|
||||
END AS [Securable],
|
||||
princ.name AS [Principal]
|
||||
FROM
|
||||
sys.server_permissions perm
|
||||
JOIN sys.server_principals princ ON perm.grantee_principal_id =
|
||||
princ.principal_id
|
||||
"""
|
||||
cursor = session.execute(text(query))
|
||||
return cursor.fetchall()
|
||||
|
||||
def _decode_if_bytes(self, value):
|
||||
if isinstance(value, bytes):
|
||||
return value.decode("utf-8")
|
||||
return value
|
||||
|
||||
def get_charset(self):
|
||||
with self.session_scope() as session:
|
||||
query = (
|
||||
"SELECT DATABASEPROPERTYEX(DB_NAME(), 'Collation') AS DatabaseCollation"
|
||||
)
|
||||
cursor = session.execute(text(query))
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result and result[0]:
|
||||
collation = self._decode_if_bytes(result[0])
|
||||
parts = collation.split("_")
|
||||
if len(parts) >= 2:
|
||||
return parts[1]
|
||||
return collation
|
||||
|
||||
return "SQL_Server_Default"
|
||||
|
||||
def get_collation(self):
|
||||
with self.session_scope() as session:
|
||||
cursor = session.execute(
|
||||
text("SELECT SERVERPROPERTY('Collation') AS DatabaseCollation")
|
||||
)
|
||||
collation = cursor.fetchone()[0]
|
||||
return collation
|
||||
|
||||
def get_table_names(self):
|
||||
tables = []
|
||||
|
||||
with self.session_scope() as session:
|
||||
query = """
|
||||
SELECT
|
||||
TABLE_SCHEMA + '.' + TABLE_NAME AS full_table_name
|
||||
FROM
|
||||
INFORMATION_SCHEMA.TABLES
|
||||
WHERE
|
||||
TABLE_TYPE = 'BASE TABLE'
|
||||
AND TABLE_CATALOG = DB_NAME()
|
||||
"""
|
||||
|
||||
cursor = session.execute(text(query))
|
||||
tables = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
if not tables:
|
||||
query = """
|
||||
SELECT
|
||||
SCHEMA_NAME(schema_id) + '.' + name AS full_table_name
|
||||
FROM
|
||||
sys.tables
|
||||
"""
|
||||
cursor = session.execute(text(query))
|
||||
tables = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
if not tables:
|
||||
query = """
|
||||
SELECT
|
||||
name AS table_name
|
||||
FROM
|
||||
sys.tables
|
||||
"""
|
||||
cursor = session.execute(text(query))
|
||||
tables = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
return tables
|
||||
|
||||
def get_columns(self, table_name: str):
|
||||
if "." in table_name:
|
||||
schema_name, pure_table_name = table_name.split(".", 1)
|
||||
else:
|
||||
schema_name = "dbo"
|
||||
pure_table_name = table_name
|
||||
|
||||
with self.session_scope() as session:
|
||||
query = """
|
||||
SELECT
|
||||
COLUMN_NAME AS name,
|
||||
DATA_TYPE AS type,
|
||||
CASE WHEN IS_NULLABLE = 'YES' THEN 1 ELSE 0 END AS nullable,
|
||||
COLUMN_DEFAULT AS default_value,
|
||||
CHARACTER_MAXIMUM_LENGTH AS max_length
|
||||
FROM
|
||||
INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE
|
||||
TABLE_SCHEMA = :schema
|
||||
AND TABLE_NAME = :table
|
||||
ORDER BY
|
||||
ORDINAL_POSITION
|
||||
"""
|
||||
cursor = session.execute(
|
||||
text(query), {"schema": schema_name, "table": pure_table_name}
|
||||
)
|
||||
results = cursor.fetchall()
|
||||
|
||||
columns = []
|
||||
for row in results:
|
||||
name = row[0].decode("utf-8") if isinstance(row[0], bytes) else row[0]
|
||||
col_type = (
|
||||
row[1].decode("utf-8") if isinstance(row[1], bytes) else row[1]
|
||||
)
|
||||
column = {
|
||||
"name": name,
|
||||
"type": col_type,
|
||||
"nullable": bool(row[2]),
|
||||
}
|
||||
|
||||
if row[3] is not None:
|
||||
default = (
|
||||
row[3].decode("utf-8") if isinstance(row[3], bytes) else row[3]
|
||||
)
|
||||
column["default"] = default
|
||||
|
||||
if row[4] is not None:
|
||||
column["max_length"] = row[4]
|
||||
columns.append(column)
|
||||
|
||||
return columns
|
||||
|
||||
def get_indexes(self, table_name: str):
|
||||
if "." in table_name:
|
||||
schema_name, pure_table_name = table_name.split(".", 1)
|
||||
else:
|
||||
schema_name = "dbo"
|
||||
pure_table_name = table_name
|
||||
|
||||
with self.session_scope() as session:
|
||||
query = """
|
||||
SELECT
|
||||
i.name AS index_name,
|
||||
c.name AS column_name,
|
||||
i.is_unique AS is_unique,
|
||||
i.is_primary_key AS is_primary_key
|
||||
FROM
|
||||
sys.indexes i
|
||||
INNER JOIN
|
||||
sys.index_columns ic ON i.object_id = ic.object_id AND i.index_id
|
||||
= ic.index_id
|
||||
INNER JOIN
|
||||
sys.columns c ON ic.object_id = c.object_id AND ic.column_id
|
||||
= c.column_id
|
||||
INNER JOIN
|
||||
sys.tables t ON i.object_id = t.object_id
|
||||
INNER JOIN
|
||||
sys.schemas s ON t.schema_id = s.schema_id
|
||||
WHERE
|
||||
t.name = :table
|
||||
AND s.name = :schema
|
||||
AND i.name IS NOT NULL
|
||||
ORDER BY
|
||||
i.name, ic.key_ordinal
|
||||
"""
|
||||
cursor = session.execute(
|
||||
text(query), {"schema": schema_name, "table": pure_table_name}
|
||||
)
|
||||
results = cursor.fetchall()
|
||||
|
||||
index_dict = {}
|
||||
for row in results:
|
||||
index_name = (
|
||||
row[0].decode("utf-8") if isinstance(row[0], bytes) else row[0]
|
||||
)
|
||||
column_name = (
|
||||
row[1].decode("utf-8") if isinstance(row[1], bytes) else row[1]
|
||||
)
|
||||
is_unique = bool(row[2])
|
||||
is_primary_key = bool(row[3])
|
||||
if index_name not in index_dict:
|
||||
index_dict[index_name] = {
|
||||
"name": index_name,
|
||||
"column_names": [],
|
||||
"unique": is_unique,
|
||||
"primary": is_primary_key,
|
||||
}
|
||||
|
||||
index_dict[index_name]["column_names"].append(column_name)
|
||||
|
||||
return list(index_dict.values())
|
||||
|
Loading…
Reference in New Issue
Block a user