feat(datasource): Support reasoning for ChatDashboard (#2401)

This commit is contained in:
Fangyin Cheng
2025-03-06 15:16:08 +08:00
committed by GitHub
parent 3bd75d8de2
commit bfd7fe8888
65 changed files with 391 additions and 216 deletions

View File

@@ -60,40 +60,12 @@ class DBSchemaAssembler(BaseAssembler):
self._connector = connector
self._table_vector_store_connector = table_vector_store_connector
self._field_vector_store_connector = field_vector_store_connector
# field_vector_store_config = VectorStoreConfig(
# name=table_vector_store_connector.vector_store_config.name + "_field"
# )
# self._field_vector_store_connector = (
# field_vector_store_connector
# or VectorStoreConnector.from_default(
# os.getenv("VECTOR_STORE_TYPE", "Chroma"),
# self._table_vector_store_connector.current_embeddings,
# vector_store_config=field_vector_store_config,
# )
# )
self._embedding_model = embedding_model
if self._embedding_model and not embeddings:
embeddings = DefaultEmbeddingFactory(
default_model_name=self._embedding_model
).create(self._embedding_model)
# if (
# embeddings
# and self._table_vector_store_connector.vector_store_config.embedding_fn
# is None
# ):
# self._table_vector_store_connector.vector_store_config.embedding_fn = (
# embeddings
# )
# if (
# embeddings
# and self._field_vector_store_connector.vector_store_config.embedding_fn
# is None
# ):
# self._field_vector_store_connector.vector_store_config.embedding_fn = (
# embeddings
# )
knowledge = DatasourceKnowledge(connector, model_dimension=max_seq_length)
super().__init__(
knowledge=knowledge,

View File

@@ -11,7 +11,11 @@ from dbgpt.rag.knowledge.base import (
KnowledgeType,
)
from ..summary.rdbms_db_summary import _parse_db_summary_with_metadata
from ..summary.rdbms_db_summary import (
_DEFAULT_COLUMN_SEPARATOR,
_DEFAULT_SUMMARY_TEMPLATE,
_parse_db_summary_with_metadata,
)
class DatasourceKnowledge(Knowledge):
@@ -20,8 +24,9 @@ class DatasourceKnowledge(Knowledge):
def __init__(
self,
connector: BaseConnector,
summary_template: str = "table_name: {table_name}",
summary_template: str = _DEFAULT_SUMMARY_TEMPLATE,
separator: str = "--table-field-separator--",
column_separator: str = _DEFAULT_COLUMN_SEPARATOR,
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
model_dimension: int = 512,
@@ -40,6 +45,7 @@ class DatasourceKnowledge(Knowledge):
model_dimension(int, optional): The threshold for splitting field string
"""
self._separator = separator
self._column_separator = column_separator
self._connector = connector
self._summary_template = summary_template
self._model_dimension = model_dimension
@@ -52,7 +58,8 @@ class DatasourceKnowledge(Knowledge):
self._connector,
self._summary_template,
self._separator,
self._model_dimension,
column_separator=self._column_separator,
model_dimension=self._model_dimension,
)
for summary, table_metadata in db_summary_with_metadata:
metadata = {"source": "database"}

View File

@@ -13,7 +13,11 @@ from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
from dbgpt.util.chat_util import run_tasks
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
from ..summary.rdbms_db_summary import _parse_db_summary
from ..summary.rdbms_db_summary import (
_DEFAULT_COLUMN_SEPARATOR,
_parse_db_summary,
_parse_table_detail,
)
logger = logging.getLogger(__name__)
@@ -28,6 +32,7 @@ class DBSchemaRetriever(BaseRetriever):
table_vector_store_connector: VectorStoreBase,
field_vector_store_connector: VectorStoreBase = None,
separator: str = "--table-field-separator--",
column_separator: str = _DEFAULT_COLUMN_SEPARATOR,
top_k: int = 4,
connector: Optional[BaseConnector] = None,
query_rewrite: bool = False,
@@ -100,6 +105,7 @@ class DBSchemaRetriever(BaseRetriever):
print(f"db struct rag example results:{result}")
"""
self._separator = separator
self._column_separator = column_separator
self._top_k = top_k
self._connector = connector
self._query_rewrite = query_rewrite
@@ -186,9 +192,11 @@ class DBSchemaRetriever(BaseRetriever):
field_chunks = self._field_vector_store_connector.similar_search_with_scores(
query, self._top_k, 0, MetadataFilters(filters=filters)
)
field_contents = [chunk.content for chunk in field_chunks]
table_chunk.content += "\n" + self._separator + "\n" + "\n".join(field_contents)
return table_chunk
field_contents = [chunk.content.strip() for chunk in field_chunks]
table_chunk.content += (
"\n" + self._separator + "\n" + self._column_separator.join(field_contents)
)
return self._deserialize_table_chunk(table_chunk)
def _similarity_search(
self, query, filters: Optional[MetadataFilters] = None
@@ -198,6 +206,7 @@ class DBSchemaRetriever(BaseRetriever):
query, self._top_k, 0, filters
)
# Find all table chunks which are not separated
not_sep_chunks = [
chunk for chunk in table_chunks if not chunk.metadata.get("separated")
]
@@ -205,9 +214,11 @@ class DBSchemaRetriever(BaseRetriever):
chunk for chunk in table_chunks if chunk.metadata.get("separated")
]
if not separated_chunks:
return not_sep_chunks
return [self._deserialize_table_chunk(chunk) for chunk in not_sep_chunks]
# Create tasks list
# The fields of table is too large, and it has to be separated into chunks,
# so we need to retrieve fields of each table separately
tasks = [
lambda c=chunk: self._retrieve_field(c, query) for chunk in separated_chunks
]
@@ -216,3 +227,32 @@ class DBSchemaRetriever(BaseRetriever):
# Combine and return results
return not_sep_chunks + separated_result
def _deserialize_table_chunk(self, chunk: Chunk) -> Chunk:
"""Deserialize table chunk."""
db_summary_version = chunk.metadata.get("db_summary_version")
if not db_summary_version:
return chunk
parts = chunk.content.split(self._separator)
table_part, field_part = parts[0].strip(), parts[1].strip()
table_detail = _parse_table_detail(table_part)
table_name = table_detail.get("table_name")
table_comment = table_detail.get("table_comment")
index_keys = table_detail.get("index_keys")
table_name = table_name.strip() if table_name else table_name
table_comment = table_comment.strip() if table_comment else table_comment
index_keys = index_keys.strip() if index_keys else index_keys
if not table_name:
return chunk
create_statement = f'CREATE TABLE "{table_name}"\r\n(\r\n '
create_statement += field_part
create_statement += "\r\n)"
if table_comment:
create_statement += f' COMMENT "{table_comment}"\r\n'
if index_keys:
create_statement += f"Index keys: {index_keys}"
chunk.content = create_statement
return chunk

View File

@@ -13,6 +13,35 @@ if TYPE_CHECKING:
CFG = Config()
_DEFAULT_SUMMARY_TEMPLATE = """\
table_name: {table_name}\r\n\
table_comment: {table_comment}\r\n\
index_keys: {index_keys}\r\n\
"""
_DEFAULT_SUMMARY_TEMPLATE_PATTEN = (
r"table_name:\s*(?P<table_name>.*)\s*"
r"table_comment:\s*(?P<table_comment>.*)\s*"
r"index_keys:\s*(?P<index_keys>.*)\s*"
)
_DEFAULT_COLUMN_SEPARATOR = ",\r\n "
def _parse_table_detail(table_desc_str: str) -> Dict[str, Any]:
"""Parse table detail string.
Args:
table_desc_str (str): table detail string
Returns:
Dict[str, Any]: A dictionary containing table_name, table_comment, and
index_keys.
"""
matched = re.match(_DEFAULT_SUMMARY_TEMPLATE_PATTEN, table_desc_str)
if matched:
return matched.groupdict()
return {}
class RdbmsSummary(DBSummary):
"""Get rdbms db table summary template.
@@ -83,8 +112,9 @@ def _parse_db_summary(
def _parse_db_summary_with_metadata(
conn: BaseConnector,
summary_template: str = "table_name: {table_name}",
summary_template: str = _DEFAULT_SUMMARY_TEMPLATE,
separator: str = "--table-field-separator--",
column_separator: str = _DEFAULT_COLUMN_SEPARATOR,
model_dimension: int = 512,
) -> List[Tuple[str, Dict[str, Any]]]:
"""Get db summary for database.
@@ -99,14 +129,21 @@ def _parse_db_summary_with_metadata(
tables = conn.get_table_names()
table_info_summaries = [
_parse_table_summary_with_metadata(
conn, summary_template, separator, table_name, model_dimension
conn,
summary_template,
separator,
table_name,
model_dimension,
column_separator=column_separator,
)
for table_name in tables
]
return table_info_summaries
def _split_columns_str(columns: List[str], model_dimension: int):
def _split_columns_str(
columns: List[str], model_dimension: int, column_separator: str = ",\r\n "
):
"""Split columns str.
Args:
@@ -129,7 +166,7 @@ def _split_columns_str(columns: List[str], model_dimension: int):
else:
# If current string is empty, add element directly
if current_string:
current_string += "," + element_str
current_string += column_separator + element_str
else:
current_string = element_str
current_length += element_length + 1 # Add length of space
@@ -147,6 +184,8 @@ def _parse_table_summary_with_metadata(
separator,
table_name: str,
model_dimension=512,
column_separator: str = _DEFAULT_COLUMN_SEPARATOR,
db_summary_version: str = "v1.0",
) -> Tuple[str, Dict[str, Any]]:
"""Get table summary for table.
@@ -168,17 +207,26 @@ def _parse_table_summary_with_metadata(
(column4,comment), (column5, comment), (column6, comment)
"""
columns = []
metadata = {"table_name": table_name, "separated": 0}
metadata = {
"table_name": table_name,
"separated": 0,
"db_summary_version": db_summary_version,
}
for column in conn.get_columns(table_name):
if column.get("comment"):
columns.append(f"{column['name']} ({column.get('comment')})")
else:
columns.append(f"{column['name']}")
col_name = column["name"]
col_type = str(column["type"]) if "type" in column else None
col_comment = column.get("comment")
column_def = f'"{col_name}" {col_type.upper()}'
if col_comment:
column_def += f' COMMENT "{col_comment}"'
columns.append(column_def)
metadata.update({"field_num": len(columns)})
separated_columns = _split_columns_str(columns, model_dimension=model_dimension)
separated_columns = _split_columns_str(
columns, model_dimension=model_dimension, column_separator=column_separator
)
if len(separated_columns) > 1:
metadata["separated"] = 1
column_str = "\n".join(separated_columns)
column_str = column_separator.join(separated_columns)
# Obtain index information
index_keys = []
raw_indexes = conn.get_indexes(table_name)
@@ -193,18 +241,19 @@ def _parse_table_summary_with_metadata(
else:
key_str = ", ".join(index["column_names"])
index_keys.append(f"{index['name']}(`{key_str}`) ")
table_str = summary_template.format(table_name=table_name)
table_comment = ""
try:
comment = conn.get_table_comment(table_name)
table_comment = comment.get("text")
except Exception:
comment = dict(text=None)
if comment.get("text"):
table_str += f"\ntable_comment: {comment.get('text')}"
pass
if len(index_keys) > 0:
index_key_str = ", ".join(index_keys)
table_str += f"\nindex_keys: {index_key_str}"
index_key_str = ", ".join(index_keys)
table_str = summary_template.format(
table_name=table_name, table_comment=table_comment, index_keys=index_key_str
)
table_str += f"\n{separator}\n{column_str}"
return table_str, metadata