mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-13 21:21:08 +00:00
refactor(ChatData):update rdbms db summary (#885)
This commit is contained in:
@@ -60,73 +60,6 @@ class ConnectManager:
|
||||
# self.storage = DuckdbConnectConfig()
|
||||
self.storage = ConnectConfigDao()
|
||||
self.db_summary_client = DBSummaryClient(system_app)
|
||||
# self.__load_config_db()
|
||||
|
||||
# def __load_config_db(self):
|
||||
# if CFG.LOCAL_DB_HOST:
|
||||
# # default mysql
|
||||
# if CFG.LOCAL_DB_NAME:
|
||||
# self.storage.add_url_db(
|
||||
# CFG.LOCAL_DB_NAME,
|
||||
# DBType.Mysql.value(),
|
||||
# CFG.LOCAL_DB_HOST,
|
||||
# CFG.LOCAL_DB_PORT,
|
||||
# CFG.LOCAL_DB_USER,
|
||||
# CFG.LOCAL_DB_PASSWORD,
|
||||
# "",
|
||||
# )
|
||||
# else:
|
||||
# # get all default mysql database
|
||||
# default_mysql = Database.from_uri(
|
||||
# "mysql+pymysql://"
|
||||
# + CFG.LOCAL_DB_USER
|
||||
# + ":"
|
||||
# + CFG.LOCAL_DB_PASSWORD
|
||||
# + "@"
|
||||
# + CFG.LOCAL_DB_HOST
|
||||
# + ":"
|
||||
# + str(CFG.LOCAL_DB_PORT),
|
||||
# engine_args={
|
||||
# "pool_size": CFG.LOCAL_DB_POOL_SIZE,
|
||||
# "pool_recycle": 3600,
|
||||
# "echo": True,
|
||||
# },
|
||||
# )
|
||||
# dbs = default_mysql.get_database_list()
|
||||
# for name in dbs:
|
||||
# self.storage.add_url_db(
|
||||
# name,
|
||||
# DBType.Mysql.value(),
|
||||
# CFG.LOCAL_DB_HOST,
|
||||
# CFG.LOCAL_DB_PORT,
|
||||
# CFG.LOCAL_DB_USER,
|
||||
# CFG.LOCAL_DB_PASSWORD,
|
||||
# "",
|
||||
# )
|
||||
# db_type = DBType.of_db_type(CFG.LOCAL_DB_TYPE)
|
||||
# if db_type.is_file_db():
|
||||
# db_name = CFG.LOCAL_DB_NAME
|
||||
# db_type = CFG.LOCAL_DB_TYPE
|
||||
# db_path = CFG.LOCAL_DB_PATH
|
||||
# if not db_type:
|
||||
# # Default file database type
|
||||
# db_type = DBType.DuckDb.value()
|
||||
# if not db_name:
|
||||
# db_type, db_name = self._parse_file_db_info(db_type, db_path)
|
||||
# if db_name:
|
||||
# print(
|
||||
# f"Add file db, db_name: {db_name}, db_type: {db_type}, db_path: {db_path}"
|
||||
# )
|
||||
# self.storage.add_file_db(db_name, db_type, db_path)
|
||||
|
||||
# def _parse_file_db_info(self, db_type: str, db_path: str):
|
||||
# if db_type is None or db_type == DBType.DuckDb.value():
|
||||
# # file db is duckdb
|
||||
# db_name = self.storage.get_file_db_name(db_path)
|
||||
# db_type = DBType.DuckDb.value()
|
||||
# else:
|
||||
# db_name = DBType.parse_file_db_name_from_path(db_type, db_path)
|
||||
# return db_type, db_name
|
||||
|
||||
def get_connect(self, db_name):
|
||||
db_config = self.storage.get_db_config(db_name)
|
||||
|
@@ -2,16 +2,12 @@ import json
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
from pilot.common.schema import DBType
|
||||
from pilot.component import SystemApp
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import (
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
EMBEDDING_MODEL_CONFIG,
|
||||
)
|
||||
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.chat_factory import ChatFactory
|
||||
from pilot.summary.rdbms_db_summary import RdbmsSummary
|
||||
|
||||
@@ -33,7 +29,6 @@ class DBSummaryClient:
|
||||
|
||||
def db_summary_embedding(self, dbname, db_type):
|
||||
"""put db profile and table profile summary into vector store"""
|
||||
from pilot.embedding_engine.string_embedding import StringEmbedding
|
||||
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
|
||||
db_summary_client = RdbmsSummary(dbname, db_type)
|
||||
@@ -43,48 +38,12 @@ class DBSummaryClient:
|
||||
embeddings = embedding_factory.create(
|
||||
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
|
||||
)
|
||||
vector_store_config = {
|
||||
"vector_store_name": dbname + "_summary",
|
||||
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
||||
"embeddings": embeddings,
|
||||
}
|
||||
embedding = StringEmbedding(
|
||||
file_path=db_summary_client.get_summary(),
|
||||
vector_store_config=vector_store_config,
|
||||
)
|
||||
self.init_db_profile(db_summary_client, dbname, embeddings)
|
||||
if not embedding.vector_name_exist():
|
||||
if CFG.SUMMARY_CONFIG == "FAST":
|
||||
for vector_table_info in db_summary_client.get_summary():
|
||||
embedding = StringEmbedding(
|
||||
vector_table_info,
|
||||
vector_store_config,
|
||||
)
|
||||
embedding.source_embedding()
|
||||
else:
|
||||
embedding = StringEmbedding(
|
||||
file_path=db_summary_client.get_summary(),
|
||||
vector_store_config=vector_store_config,
|
||||
)
|
||||
embedding.source_embedding()
|
||||
for (
|
||||
table_name,
|
||||
table_summary,
|
||||
) in db_summary_client.get_table_summary().items():
|
||||
table_vector_store_config = {
|
||||
"vector_store_name": dbname + "_" + table_name + "_ts",
|
||||
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
||||
"embeddings": embeddings,
|
||||
}
|
||||
embedding = StringEmbedding(
|
||||
table_summary,
|
||||
table_vector_store_config,
|
||||
)
|
||||
embedding.source_embedding()
|
||||
|
||||
logger.info("db summary embedding success")
|
||||
|
||||
def get_db_summary(self, dbname, query, topk):
|
||||
"""get user query related tables info"""
|
||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
|
||||
@@ -104,53 +63,8 @@ class DBSummaryClient:
|
||||
ans = [d.page_content for d in table_docs]
|
||||
return ans
|
||||
|
||||
def get_similar_tables(self, dbname, query, topk):
|
||||
"""get user query related tables info"""
|
||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
|
||||
vector_store_config = {
|
||||
"vector_store_name": dbname + "_summary",
|
||||
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
||||
}
|
||||
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||
"embedding_factory", EmbeddingFactory
|
||||
)
|
||||
knowledge_embedding_client = EmbeddingEngine(
|
||||
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||
vector_store_config=vector_store_config,
|
||||
embedding_factory=embedding_factory,
|
||||
)
|
||||
if CFG.SUMMARY_CONFIG == "FAST":
|
||||
table_docs = knowledge_embedding_client.similar_search(query, topk)
|
||||
related_tables = [
|
||||
json.loads(table_doc.page_content)["table_name"]
|
||||
for table_doc in table_docs
|
||||
]
|
||||
else:
|
||||
table_docs = knowledge_embedding_client.similar_search(query, 1)
|
||||
# prompt = KnownLedgeBaseQA.build_db_summary_prompt(
|
||||
# query, table_docs[0].page_content
|
||||
# )
|
||||
related_tables = _get_llm_response(
|
||||
query, dbname, table_docs[0].page_content
|
||||
)
|
||||
related_table_summaries = []
|
||||
for table in related_tables:
|
||||
vector_store_config = {
|
||||
"vector_store_name": dbname + "_" + table + "_ts",
|
||||
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
||||
}
|
||||
knowledge_embedding_client = EmbeddingEngine(
|
||||
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||
vector_store_config=vector_store_config,
|
||||
embedding_factory=embedding_factory,
|
||||
)
|
||||
table_summary = knowledge_embedding_client.similar_search(query, 1)
|
||||
related_table_summaries.append(table_summary[0].page_content)
|
||||
return related_table_summaries
|
||||
|
||||
def init_db_summary(self):
|
||||
"""init db summary"""
|
||||
db_mange = CFG.LOCAL_DB_MANAGE
|
||||
dbs = db_mange.get_db_list()
|
||||
for item in dbs:
|
||||
@@ -177,17 +91,16 @@ class DBSummaryClient:
|
||||
"embeddings": embeddings,
|
||||
}
|
||||
embedding = StringEmbedding(
|
||||
file_path=db_summary_client.get_db_summary(),
|
||||
file_path=None,
|
||||
vector_store_config=profile_store_config,
|
||||
)
|
||||
if not embedding.vector_name_exist():
|
||||
docs = []
|
||||
docs.extend(embedding.read_batch())
|
||||
for table_summary in db_summary_client.table_info_json():
|
||||
for table_summary in db_summary_client.table_summaries():
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=len(table_summary), chunk_overlap=100
|
||||
chunk_size=len(table_summary), chunk_overlap=0
|
||||
)
|
||||
embedding = StringEmbedding(
|
||||
file_path=table_summary,
|
||||
@@ -195,23 +108,8 @@ class DBSummaryClient:
|
||||
text_splitter=text_splitter,
|
||||
)
|
||||
docs.extend(embedding.read_batch())
|
||||
embedding.index_to_store(docs)
|
||||
if len(docs) > 0:
|
||||
embedding.index_to_store(docs)
|
||||
else:
|
||||
logger.info(f"Vector store name {vector_store_name} exist")
|
||||
logger.info("init db profile success...")
|
||||
|
||||
|
||||
def _get_llm_response(query, db_input, dbsummary):
|
||||
chat_param = {
|
||||
"temperature": 0.7,
|
||||
"max_new_tokens": 512,
|
||||
"chat_session_id": uuid.uuid1(),
|
||||
"user_input": query,
|
||||
"db_select": db_input,
|
||||
"db_summary": dbsummary,
|
||||
}
|
||||
chat: BaseChat = chat_factory.get_implementation(
|
||||
ChatScene.InnerChatDBSummary.value, **chat_param
|
||||
)
|
||||
res = chat._blocking_nostream_call()
|
||||
return json.loads(res)["table"]
|
||||
logger.info("initialize db summary profile success...")
|
||||
|
@@ -1,22 +1,22 @@
|
||||
import json
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.summary.db_summary import DBSummary, TableSummary, FieldSummary, IndexSummary
|
||||
from pilot.summary.db_summary import DBSummary
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class RdbmsSummary(DBSummary):
|
||||
"""Get mysql summary template."""
|
||||
"""Get rdbms db table summary template.
|
||||
summary example:
|
||||
table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment is {table_comment})
|
||||
"""
|
||||
|
||||
def __init__(self, name, type):
|
||||
self.name = name
|
||||
self.type = type
|
||||
self.summary = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
|
||||
self.summary_template = "{table_name}({columns})"
|
||||
self.tables = {}
|
||||
self.tables_info = []
|
||||
self.vector_tables_info = []
|
||||
# self.tables_summary = {}
|
||||
|
||||
self.db = CFG.LOCAL_DB_MANAGE.get_connect(name)
|
||||
|
||||
@@ -27,154 +27,41 @@ class RdbmsSummary(DBSummary):
|
||||
collation=self.db.get_collation(),
|
||||
)
|
||||
tables = self.db.get_table_names()
|
||||
self.table_comments = self.db.get_table_comments(name)
|
||||
comment_map = {}
|
||||
for table_comment in self.table_comments:
|
||||
self.tables_info.append(
|
||||
"table name:{table_name},table description:{table_comment}".format(
|
||||
table_name=table_comment[0], table_comment=table_comment[1]
|
||||
)
|
||||
)
|
||||
comment_map[table_comment[0]] = table_comment[1]
|
||||
self.table_info_summaries = [
|
||||
self.get_table_summary(table_name) for table_name in tables
|
||||
]
|
||||
|
||||
vector_table = json.dumps(
|
||||
{"table_name": table_comment[0], "table_description": table_comment[1]}
|
||||
)
|
||||
self.vector_tables_info.append(
|
||||
vector_table.encode("utf-8").decode("unicode_escape")
|
||||
)
|
||||
self.table_columns_info = []
|
||||
self.table_columns_json = []
|
||||
def get_table_summary(self, table_name):
|
||||
"""Get table summary for table.
|
||||
example:
|
||||
table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment: {table_comment})
|
||||
"""
|
||||
columns = []
|
||||
for column in self.db._inspector.get_columns(table_name):
|
||||
if column.get("comment"):
|
||||
columns.append((f"{column['name']} ({column.get('comment')})"))
|
||||
else:
|
||||
columns.append(f"{column['name']}")
|
||||
|
||||
for table_name in tables:
|
||||
table_summary = RdbmsTableSummary(self.db, name, table_name, comment_map)
|
||||
# self.tables[table_name] = table_summary.get_summary()
|
||||
self.tables[table_name] = table_summary.get_columns()
|
||||
self.table_columns_info.append(table_summary.get_columns())
|
||||
# self.table_columns_json.append(table_summary.get_summary_json())
|
||||
table_profile = (
|
||||
"table name:{table_name},table description:{table_comment}".format(
|
||||
table_name=table_name,
|
||||
table_comment=self.db.get_show_create_table(table_name),
|
||||
)
|
||||
)
|
||||
self.table_columns_json.append(table_profile)
|
||||
# self.tables_info.append(table_summary.get_summary())
|
||||
|
||||
def get_summary(self):
|
||||
if CFG.SUMMARY_CONFIG == "FAST":
|
||||
return self.vector_tables_info
|
||||
else:
|
||||
return self.summary.format(
|
||||
name=self.name, type=self.type, table_info=";".join(self.tables_info)
|
||||
)
|
||||
|
||||
def get_db_summary(self):
|
||||
return self.summary.format(
|
||||
name=self.name,
|
||||
type=self.type,
|
||||
tables=";".join(self.vector_tables_info),
|
||||
qps=1000,
|
||||
tps=1000,
|
||||
column_str = ", ".join(columns)
|
||||
index_keys = []
|
||||
for index_key in self.db._inspector.get_indexes(table_name):
|
||||
key_str = ", ".join(index_key["column_names"])
|
||||
index_keys.append(f"{index_key['name']}(`{key_str}`) ")
|
||||
table_str = self.summary_template.format(
|
||||
table_name=table_name, columns=column_str
|
||||
)
|
||||
if len(index_keys) > 0:
|
||||
index_key_str = ", ".join(index_keys)
|
||||
table_str += f", and index keys: {index_key_str}"
|
||||
try:
|
||||
comment = self.db._inspector.get_table_comment(table_name)
|
||||
except Exception:
|
||||
comment = dict(text=None)
|
||||
if comment.get("text"):
|
||||
table_str += f", and table comment: {comment.get('text')}"
|
||||
return table_str
|
||||
|
||||
def get_table_summary(self):
|
||||
return self.tables
|
||||
|
||||
def get_table_comments(self):
|
||||
return self.table_comments
|
||||
|
||||
def table_info_json(self):
|
||||
return self.table_columns_json
|
||||
|
||||
|
||||
class RdbmsTableSummary(TableSummary):
|
||||
"""Get mysql table summary template."""
|
||||
|
||||
def __init__(self, instance, dbname, name, comment_map):
|
||||
self.name = name
|
||||
self.dbname = dbname
|
||||
self.summary = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}"""
|
||||
self.json_summary_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}"""
|
||||
self.fields = []
|
||||
self.fields_info = []
|
||||
self.indexes = []
|
||||
self.indexes_info = []
|
||||
self.db = instance
|
||||
fields = self.db.get_fields(name)
|
||||
indexes = self.db.get_indexes(name)
|
||||
field_names = []
|
||||
for field in fields:
|
||||
field_summary = RdbmsFieldsSummary(field)
|
||||
self.fields.append(field_summary)
|
||||
self.fields_info.append(field_summary.get_summary())
|
||||
field_names.append(field[0])
|
||||
|
||||
self.column_summary = """{name}({columns_info})""".format(
|
||||
name=name, columns_info=",".join(field_names)
|
||||
)
|
||||
|
||||
for index in indexes:
|
||||
index_summary = RdbmsIndexSummary(index)
|
||||
self.indexes.append(index_summary)
|
||||
self.indexes_info.append(index_summary.get_summary())
|
||||
|
||||
self.json_summary = self.json_summary_template.format(
|
||||
name=name,
|
||||
comment=comment_map[name],
|
||||
fields=self.fields_info,
|
||||
indexes=self.indexes_info,
|
||||
size_in_bytes=1000,
|
||||
rows=1000,
|
||||
)
|
||||
|
||||
def get_summary(self):
|
||||
return self.summary.format(
|
||||
name=self.name,
|
||||
dbname=self.dbname,
|
||||
fields=";".join(self.fields_info),
|
||||
indexes=";".join(self.indexes_info),
|
||||
)
|
||||
|
||||
def get_columns(self):
|
||||
return self.column_summary
|
||||
|
||||
def get_summary_json(self):
|
||||
return self.json_summary
|
||||
|
||||
|
||||
class RdbmsFieldsSummary(FieldSummary):
|
||||
"""Get mysql field summary template."""
|
||||
|
||||
def __init__(self, field):
|
||||
self.name = field[0]
|
||||
# self.summary = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """
|
||||
# self.summary = """{"name": {name}, "type": {data_type}, "is_primary_key": {is_nullable}, "comment":{comment}, "default":{default_value}}"""
|
||||
self.data_type = field[1]
|
||||
self.default_value = field[2]
|
||||
self.is_nullable = field[3]
|
||||
self.comment = field[4]
|
||||
|
||||
def get_summary(self):
|
||||
return '{{"name": "{name}", "type": "{data_type}", "is_primary_key": "{is_nullable}", "comment": "{comment}", "default": "{default_value}"}}'.format(
|
||||
name=self.name,
|
||||
data_type=self.data_type,
|
||||
is_nullable=self.is_nullable,
|
||||
default_value=self.default_value,
|
||||
comment=self.comment,
|
||||
)
|
||||
|
||||
|
||||
class RdbmsIndexSummary(IndexSummary):
|
||||
"""Get mysql index summary template."""
|
||||
|
||||
def __init__(self, index):
|
||||
self.name = index[0]
|
||||
# self.summary = """index name:{name}, index bind columns:{bind_fields}"""
|
||||
self.summary_template = '{{"name": "{name}", "columns": {bind_fields}}}'
|
||||
self.bind_fields = index[1]
|
||||
|
||||
def get_summary(self):
|
||||
return self.summary_template.format(
|
||||
name=self.name, bind_fields=self.bind_fields
|
||||
)
|
||||
def table_summaries(self):
|
||||
"""Get table summaries."""
|
||||
return self.table_info_summaries
|
||||
|
Reference in New Issue
Block a user