diff --git a/dbgpt/rag/summary/rdbms_db_summary.py b/dbgpt/rag/summary/rdbms_db_summary.py index 6397b53d3..337d3851b 100644 --- a/dbgpt/rag/summary/rdbms_db_summary.py +++ b/dbgpt/rag/summary/rdbms_db_summary.py @@ -1,5 +1,5 @@ """Summary for rdbms database.""" - +import re from typing import TYPE_CHECKING, List, Optional from dbgpt._private.config import Config @@ -102,10 +102,20 @@ def _parse_table_summary( columns.append(f"{column['name']}") column_str = ", ".join(columns) + # Obtain index information index_keys = [] - for index_key in conn.get_indexes(table_name): - key_str = ", ".join(index_key["column_names"]) - index_keys.append(f"{index_key['name']}(`{key_str}`) ") # noqa + raw_indexes = conn.get_indexes(table_name) + for index in raw_indexes: + if isinstance(index, tuple): # Process tuple type index information + index_name, index_creation_command = index + # Extract column names using re + matched_columns = re.findall(r"\(([^)]+)\)", index_creation_command) + if matched_columns: + key_str = ", ".join(matched_columns) + index_keys.append(f"{index_name}(`{key_str}`) ") + else: + key_str = ", ".join(index["column_names"]) + index_keys.append(f"{index['name']}(`{key_str}`) ") table_str = summary_template.format(table_name=table_name, columns=column_str) if len(index_keys) > 0: index_key_str = ", ".join(index_keys)