Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-08 20:39:44 +00:00)
feat(datasource): Support reasoning for ChatDashboard (#2401)
@@ -60,40 +60,12 @@ class DBSchemaAssembler(BaseAssembler):
         self._connector = connector
         self._table_vector_store_connector = table_vector_store_connector
         self._field_vector_store_connector = field_vector_store_connector
-        # field_vector_store_config = VectorStoreConfig(
-        #     name=table_vector_store_connector.vector_store_config.name + "_field"
-        # )
-        # self._field_vector_store_connector = (
-        #     field_vector_store_connector
-        #     or VectorStoreConnector.from_default(
-        #         os.getenv("VECTOR_STORE_TYPE", "Chroma"),
-        #         self._table_vector_store_connector.current_embeddings,
-        #         vector_store_config=field_vector_store_config,
-        #     )
-        # )

         self._embedding_model = embedding_model
         if self._embedding_model and not embeddings:
             embeddings = DefaultEmbeddingFactory(
                 default_model_name=self._embedding_model
             ).create(self._embedding_model)

-        # if (
-        #     embeddings
-        #     and self._table_vector_store_connector.vector_store_config.embedding_fn
-        #     is None
-        # ):
-        #     self._table_vector_store_connector.vector_store_config.embedding_fn = (
-        #         embeddings
-        #     )
-        # if (
-        #     embeddings
-        #     and self._field_vector_store_connector.vector_store_config.embedding_fn
-        #     is None
-        # ):
-        #     self._field_vector_store_connector.vector_store_config.embedding_fn = (
-        #         embeddings
-        #     )
         knowledge = DatasourceKnowledge(connector, model_dimension=max_seq_length)
         super().__init__(
             knowledge=knowledge,

@@ -11,7 +11,11 @@ from dbgpt.rag.knowledge.base import (
     KnowledgeType,
 )

-from ..summary.rdbms_db_summary import _parse_db_summary_with_metadata
+from ..summary.rdbms_db_summary import (
+    _DEFAULT_COLUMN_SEPARATOR,
+    _DEFAULT_SUMMARY_TEMPLATE,
+    _parse_db_summary_with_metadata,
+)


 class DatasourceKnowledge(Knowledge):
@@ -20,8 +24,9 @@ class DatasourceKnowledge(Knowledge):
     def __init__(
         self,
         connector: BaseConnector,
-        summary_template: str = "table_name: {table_name}",
+        summary_template: str = _DEFAULT_SUMMARY_TEMPLATE,
         separator: str = "--table-field-separator--",
+        column_separator: str = _DEFAULT_COLUMN_SEPARATOR,
         knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
         metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
         model_dimension: int = 512,
@@ -40,6 +45,7 @@ class DatasourceKnowledge(Knowledge):
             model_dimension(int, optional): The threshold for splitting field string
         """
         self._separator = separator
+        self._column_separator = column_separator
         self._connector = connector
         self._summary_template = summary_template
         self._model_dimension = model_dimension
@@ -52,7 +58,8 @@ class DatasourceKnowledge(Knowledge):
             self._connector,
             self._summary_template,
             self._separator,
-            self._model_dimension,
+            column_separator=self._column_separator,
+            model_dimension=self._model_dimension,
         )
         for summary, table_metadata in db_summary_with_metadata:
             metadata = {"source": "database"}

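Not part of the commit: a minimal sketch of the chunk text these new defaults produce, assuming the _DEFAULT_SUMMARY_TEMPLATE, the --table-field-separator-- separator, and the _DEFAULT_COLUMN_SEPARATOR values shown in the rdbms_db_summary hunks further down. The "user" table, its columns, and the index name are invented example data.

# Sketch only: the values below mirror the new defaults from this diff;
# the table, columns, and index are invented.
summary_template = (
    "table_name: {table_name}\r\n"
    "table_comment: {table_comment}\r\n"
    "index_keys: {index_keys}\r\n"
)
separator = "--table-field-separator--"
column_separator = ",\r\n "

table_part = summary_template.format(
    table_name="user",
    table_comment="application users",
    index_keys="ix_user_name(`name`)",
)
field_part = column_separator.join(
    ['"id" INTEGER', '"name" VARCHAR(100) COMMENT "login name"']
)
# Same assembly as _parse_table_summary_with_metadata: header, separator, columns.
chunk_content = f"{table_part}\n{separator}\n{field_part}"
print(chunk_content)
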
@@ -13,7 +13,11 @@ from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
 from dbgpt.util.chat_util import run_tasks
 from dbgpt.util.executor_utils import blocking_func_to_async_no_executor

-from ..summary.rdbms_db_summary import _parse_db_summary
+from ..summary.rdbms_db_summary import (
+    _DEFAULT_COLUMN_SEPARATOR,
+    _parse_db_summary,
+    _parse_table_detail,
+)

 logger = logging.getLogger(__name__)

@@ -28,6 +32,7 @@ class DBSchemaRetriever(BaseRetriever):
         table_vector_store_connector: VectorStoreBase,
         field_vector_store_connector: VectorStoreBase = None,
         separator: str = "--table-field-separator--",
+        column_separator: str = _DEFAULT_COLUMN_SEPARATOR,
         top_k: int = 4,
         connector: Optional[BaseConnector] = None,
         query_rewrite: bool = False,
@@ -100,6 +105,7 @@ class DBSchemaRetriever(BaseRetriever):
             print(f"db struct rag example results:{result}")
         """
         self._separator = separator
+        self._column_separator = column_separator
         self._top_k = top_k
         self._connector = connector
         self._query_rewrite = query_rewrite
@@ -186,9 +192,11 @@ class DBSchemaRetriever(BaseRetriever):
         field_chunks = self._field_vector_store_connector.similar_search_with_scores(
             query, self._top_k, 0, MetadataFilters(filters=filters)
         )
-        field_contents = [chunk.content for chunk in field_chunks]
-        table_chunk.content += "\n" + self._separator + "\n" + "\n".join(field_contents)
-        return table_chunk
+        field_contents = [chunk.content.strip() for chunk in field_chunks]
+        table_chunk.content += (
+            "\n" + self._separator + "\n" + self._column_separator.join(field_contents)
+        )
+        return self._deserialize_table_chunk(table_chunk)

     def _similarity_search(
         self, query, filters: Optional[MetadataFilters] = None
@@ -198,6 +206,7 @@ class DBSchemaRetriever(BaseRetriever):
             query, self._top_k, 0, filters
         )

+        # Find all table chunks which are not separated
         not_sep_chunks = [
             chunk for chunk in table_chunks if not chunk.metadata.get("separated")
         ]
@@ -205,9 +214,11 @@ class DBSchemaRetriever(BaseRetriever):
             chunk for chunk in table_chunks if chunk.metadata.get("separated")
         ]
         if not separated_chunks:
-            return not_sep_chunks
+            return [self._deserialize_table_chunk(chunk) for chunk in not_sep_chunks]

-        # Create tasks list
+        # The fields of table is too large, and it has to be separated into chunks,
+        # so we need to retrieve fields of each table separately
         tasks = [
             lambda c=chunk: self._retrieve_field(c, query) for chunk in separated_chunks
         ]
@@ -216,3 +227,32 @@ class DBSchemaRetriever(BaseRetriever):

         # Combine and return results
         return not_sep_chunks + separated_result
+
+    def _deserialize_table_chunk(self, chunk: Chunk) -> Chunk:
+        """Deserialize table chunk."""
+        db_summary_version = chunk.metadata.get("db_summary_version")
+        if not db_summary_version:
+            return chunk
+        parts = chunk.content.split(self._separator)
+        table_part, field_part = parts[0].strip(), parts[1].strip()
+        table_detail = _parse_table_detail(table_part)
+        table_name = table_detail.get("table_name")
+        table_comment = table_detail.get("table_comment")
+        index_keys = table_detail.get("index_keys")
+
+        table_name = table_name.strip() if table_name else table_name
+        table_comment = table_comment.strip() if table_comment else table_comment
+        index_keys = index_keys.strip() if index_keys else index_keys
+        if not table_name:
+            return chunk
+
+        create_statement = f'CREATE TABLE "{table_name}"\r\n(\r\n '
+        create_statement += field_part
+        create_statement += "\r\n)"
+        if table_comment:
+            create_statement += f' COMMENT "{table_comment}"\r\n'
+        if index_keys:
+            create_statement += f"Index keys: {index_keys}"
+
+        chunk.content = create_statement
+        return chunk

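Not part of the commit: a minimal sketch of the reconstruction _deserialize_table_chunk performs on a chunk whose metadata carries db_summary_version. The stored content follows the layout sketched earlier; the second string is roughly what chunk.content becomes after the method runs. All table data is invented, and the exact whitespace depends on the template constants.

# Sketch only: example input and output for _deserialize_table_chunk
# (invented data; exact whitespace depends on the separator constants).
stored_content = (
    "table_name: user\r\n"
    "table_comment: application users\r\n"
    "index_keys: ix_user_name(`name`)\r\n"
    "\n--table-field-separator--\n"
    '"id" INTEGER,\r\n "name" VARCHAR(100) COMMENT "login name"'
)
# Roughly the reconstructed CREATE TABLE-style content:
reconstructed = (
    'CREATE TABLE "user"\r\n(\r\n '
    '"id" INTEGER,\r\n "name" VARCHAR(100) COMMENT "login name"'
    "\r\n)"
    ' COMMENT "application users"\r\n'
    "Index keys: ix_user_name(`name`)"
)
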
@@ -13,6 +13,35 @@ if TYPE_CHECKING:
 CFG = Config()


+_DEFAULT_SUMMARY_TEMPLATE = """\
+table_name: {table_name}\r\n\
+table_comment: {table_comment}\r\n\
+index_keys: {index_keys}\r\n\
+"""
+_DEFAULT_SUMMARY_TEMPLATE_PATTEN = (
+    r"table_name:\s*(?P<table_name>.*)\s*"
+    r"table_comment:\s*(?P<table_comment>.*)\s*"
+    r"index_keys:\s*(?P<index_keys>.*)\s*"
+)
+_DEFAULT_COLUMN_SEPARATOR = ",\r\n "
+
+
+def _parse_table_detail(table_desc_str: str) -> Dict[str, Any]:
+    """Parse table detail string.
+
+    Args:
+        table_desc_str (str): table detail string
+
+    Returns:
+        Dict[str, Any]: A dictionary containing table_name, table_comment, and
+            index_keys.
+    """
+    matched = re.match(_DEFAULT_SUMMARY_TEMPLATE_PATTEN, table_desc_str)
+    if matched:
+        return matched.groupdict()
+    return {}
+
+
 class RdbmsSummary(DBSummary):
     """Get rdbms db table summary template.

@@ -83,8 +112,9 @@ def _parse_db_summary(

 def _parse_db_summary_with_metadata(
     conn: BaseConnector,
-    summary_template: str = "table_name: {table_name}",
+    summary_template: str = _DEFAULT_SUMMARY_TEMPLATE,
     separator: str = "--table-field-separator--",
+    column_separator: str = _DEFAULT_COLUMN_SEPARATOR,
     model_dimension: int = 512,
 ) -> List[Tuple[str, Dict[str, Any]]]:
     """Get db summary for database.
@@ -99,14 +129,21 @@ def _parse_db_summary_with_metadata(
     tables = conn.get_table_names()
     table_info_summaries = [
         _parse_table_summary_with_metadata(
-            conn, summary_template, separator, table_name, model_dimension
+            conn,
+            summary_template,
+            separator,
+            table_name,
+            model_dimension,
+            column_separator=column_separator,
         )
         for table_name in tables
     ]
     return table_info_summaries


-def _split_columns_str(columns: List[str], model_dimension: int):
+def _split_columns_str(
+    columns: List[str], model_dimension: int, column_separator: str = ",\r\n "
+):
     """Split columns str.

     Args:
@@ -129,7 +166,7 @@ def _split_columns_str(columns: List[str], model_dimension: int):
         else:
             # If current string is empty, add element directly
             if current_string:
-                current_string += "," + element_str
+                current_string += column_separator + element_str
             else:
                 current_string = element_str
             current_length += element_length + 1  # Add length of space
@@ -147,6 +184,8 @@ def _parse_table_summary_with_metadata(
     separator,
     table_name: str,
     model_dimension=512,
+    column_separator: str = _DEFAULT_COLUMN_SEPARATOR,
+    db_summary_version: str = "v1.0",
 ) -> Tuple[str, Dict[str, Any]]:
     """Get table summary for table.

@@ -168,17 +207,26 @@ def _parse_table_summary_with_metadata(
             (column4,comment), (column5, comment), (column6, comment)
     """
     columns = []
-    metadata = {"table_name": table_name, "separated": 0}
+    metadata = {
+        "table_name": table_name,
+        "separated": 0,
+        "db_summary_version": db_summary_version,
+    }
     for column in conn.get_columns(table_name):
-        if column.get("comment"):
-            columns.append(f"{column['name']} ({column.get('comment')})")
-        else:
-            columns.append(f"{column['name']}")
+        col_name = column["name"]
+        col_type = str(column["type"]) if "type" in column else None
+        col_comment = column.get("comment")
+        column_def = f'"{col_name}" {col_type.upper()}'
+        if col_comment:
+            column_def += f' COMMENT "{col_comment}"'
+        columns.append(column_def)
     metadata.update({"field_num": len(columns)})
-    separated_columns = _split_columns_str(columns, model_dimension=model_dimension)
+    separated_columns = _split_columns_str(
+        columns, model_dimension=model_dimension, column_separator=column_separator
+    )
     if len(separated_columns) > 1:
         metadata["separated"] = 1
-    column_str = "\n".join(separated_columns)
+    column_str = column_separator.join(separated_columns)
     # Obtain index information
     index_keys = []
     raw_indexes = conn.get_indexes(table_name)
@@ -193,18 +241,19 @@ def _parse_table_summary_with_metadata(
         else:
             key_str = ", ".join(index["column_names"])
         index_keys.append(f"{index['name']}(`{key_str}`) ")
-    table_str = summary_template.format(table_name=table_name)
+
+    table_comment = ""

     try:
         comment = conn.get_table_comment(table_name)
+        table_comment = comment.get("text")
     except Exception:
-        comment = dict(text=None)
-    if comment.get("text"):
-        table_str += f"\ntable_comment: {comment.get('text')}"
+        pass

-    if len(index_keys) > 0:
-        index_key_str = ", ".join(index_keys)
-        table_str += f"\nindex_keys: {index_key_str}"
+    index_key_str = ", ".join(index_keys)
+    table_str = summary_template.format(
+        table_name=table_name, table_comment=table_comment, index_keys=index_key_str
+    )
     table_str += f"\n{separator}\n{column_str}"
     return table_str, metadata

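Not part of the commit: a minimal sketch of how _DEFAULT_SUMMARY_TEMPLATE_PATTEN recovers the header fields from a rendered summary string. The pattern is copied verbatim from the hunk above; the sample string is invented. The captured values can keep a trailing carriage return, which is why the retriever strips each field after _parse_table_detail.

import re

# Sketch only: pattern copied from the diff above; sample data is invented.
PATTEN = (
    r"table_name:\s*(?P<table_name>.*)\s*"
    r"table_comment:\s*(?P<table_comment>.*)\s*"
    r"index_keys:\s*(?P<index_keys>.*)\s*"
)
sample = (
    "table_name: user\r\n"
    "table_comment: application users\r\n"
    "index_keys: ix_user_name(`name`)\r\n"
)
matched = re.match(PATTEN, sample)
if matched:
    # Strip each captured value, as the retriever does after parsing.
    detail = {key: value.strip() for key, value in matched.groupdict().items()}
    print(detail)
    # {'table_name': 'user', 'table_comment': 'application users',
    #  'index_keys': 'ix_user_name(`name`)'}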