refactor(ChatData):update rdbms db summary (#885)

This commit is contained in:
Aries-ckt
2023-12-04 22:16:28 +08:00
committed by GitHub
parent b12a858d53
commit a2f087c429
3 changed files with 48 additions and 330 deletions

View File

@@ -60,73 +60,6 @@ class ConnectManager:
# self.storage = DuckdbConnectConfig()
self.storage = ConnectConfigDao()
self.db_summary_client = DBSummaryClient(system_app)
# self.__load_config_db()
# def __load_config_db(self):
# if CFG.LOCAL_DB_HOST:
# # default mysql
# if CFG.LOCAL_DB_NAME:
# self.storage.add_url_db(
# CFG.LOCAL_DB_NAME,
# DBType.Mysql.value(),
# CFG.LOCAL_DB_HOST,
# CFG.LOCAL_DB_PORT,
# CFG.LOCAL_DB_USER,
# CFG.LOCAL_DB_PASSWORD,
# "",
# )
# else:
# # get all default mysql database
# default_mysql = Database.from_uri(
# "mysql+pymysql://"
# + CFG.LOCAL_DB_USER
# + ":"
# + CFG.LOCAL_DB_PASSWORD
# + "@"
# + CFG.LOCAL_DB_HOST
# + ":"
# + str(CFG.LOCAL_DB_PORT),
# engine_args={
# "pool_size": CFG.LOCAL_DB_POOL_SIZE,
# "pool_recycle": 3600,
# "echo": True,
# },
# )
# dbs = default_mysql.get_database_list()
# for name in dbs:
# self.storage.add_url_db(
# name,
# DBType.Mysql.value(),
# CFG.LOCAL_DB_HOST,
# CFG.LOCAL_DB_PORT,
# CFG.LOCAL_DB_USER,
# CFG.LOCAL_DB_PASSWORD,
# "",
# )
# db_type = DBType.of_db_type(CFG.LOCAL_DB_TYPE)
# if db_type.is_file_db():
# db_name = CFG.LOCAL_DB_NAME
# db_type = CFG.LOCAL_DB_TYPE
# db_path = CFG.LOCAL_DB_PATH
# if not db_type:
# # Default file database type
# db_type = DBType.DuckDb.value()
# if not db_name:
# db_type, db_name = self._parse_file_db_info(db_type, db_path)
# if db_name:
# print(
# f"Add file db, db_name: {db_name}, db_type: {db_type}, db_path: {db_path}"
# )
# self.storage.add_file_db(db_name, db_type, db_path)
# def _parse_file_db_info(self, db_type: str, db_path: str):
# if db_type is None or db_type == DBType.DuckDb.value():
# # file db is duckdb
# db_name = self.storage.get_file_db_name(db_path)
# db_type = DBType.DuckDb.value()
# else:
# db_name = DBType.parse_file_db_name_from_path(db_type, db_path)
# return db_type, db_name
def get_connect(self, db_name):
db_config = self.storage.get_db_config(db_name)

View File

@@ -2,16 +2,12 @@ import json
import uuid
import logging
from pilot.common.schema import DBType
from pilot.component import SystemApp
from pilot.configs.config import Config
from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
EMBEDDING_MODEL_CONFIG,
)
from pilot.scene.base import ChatScene
from pilot.scene.base_chat import BaseChat
from pilot.scene.chat_factory import ChatFactory
from pilot.summary.rdbms_db_summary import RdbmsSummary
@@ -33,7 +29,6 @@ class DBSummaryClient:
def db_summary_embedding(self, dbname, db_type):
"""put db profile and table profile summary into vector store"""
from pilot.embedding_engine.string_embedding import StringEmbedding
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
db_summary_client = RdbmsSummary(dbname, db_type)
@@ -43,48 +38,12 @@ class DBSummaryClient:
embeddings = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
vector_store_config = {
"vector_store_name": dbname + "_summary",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"embeddings": embeddings,
}
embedding = StringEmbedding(
file_path=db_summary_client.get_summary(),
vector_store_config=vector_store_config,
)
self.init_db_profile(db_summary_client, dbname, embeddings)
if not embedding.vector_name_exist():
if CFG.SUMMARY_CONFIG == "FAST":
for vector_table_info in db_summary_client.get_summary():
embedding = StringEmbedding(
vector_table_info,
vector_store_config,
)
embedding.source_embedding()
else:
embedding = StringEmbedding(
file_path=db_summary_client.get_summary(),
vector_store_config=vector_store_config,
)
embedding.source_embedding()
for (
table_name,
table_summary,
) in db_summary_client.get_table_summary().items():
table_vector_store_config = {
"vector_store_name": dbname + "_" + table_name + "_ts",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"embeddings": embeddings,
}
embedding = StringEmbedding(
table_summary,
table_vector_store_config,
)
embedding.source_embedding()
logger.info("db summary embedding success")
def get_db_summary(self, dbname, query, topk):
"""get user query related tables info"""
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
@@ -104,53 +63,8 @@ class DBSummaryClient:
ans = [d.page_content for d in table_docs]
return ans
def get_similar_tables(self, dbname, query, topk):
"""get user query related tables info"""
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
vector_store_config = {
"vector_store_name": dbname + "_summary",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
}
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
embedding_factory=embedding_factory,
)
if CFG.SUMMARY_CONFIG == "FAST":
table_docs = knowledge_embedding_client.similar_search(query, topk)
related_tables = [
json.loads(table_doc.page_content)["table_name"]
for table_doc in table_docs
]
else:
table_docs = knowledge_embedding_client.similar_search(query, 1)
# prompt = KnownLedgeBaseQA.build_db_summary_prompt(
# query, table_docs[0].page_content
# )
related_tables = _get_llm_response(
query, dbname, table_docs[0].page_content
)
related_table_summaries = []
for table in related_tables:
vector_store_config = {
"vector_store_name": dbname + "_" + table + "_ts",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
}
knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
embedding_factory=embedding_factory,
)
table_summary = knowledge_embedding_client.similar_search(query, 1)
related_table_summaries.append(table_summary[0].page_content)
return related_table_summaries
def init_db_summary(self):
"""init db summary"""
db_mange = CFG.LOCAL_DB_MANAGE
dbs = db_mange.get_db_list()
for item in dbs:
@@ -177,17 +91,16 @@ class DBSummaryClient:
"embeddings": embeddings,
}
embedding = StringEmbedding(
file_path=db_summary_client.get_db_summary(),
file_path=None,
vector_store_config=profile_store_config,
)
if not embedding.vector_name_exist():
docs = []
docs.extend(embedding.read_batch())
for table_summary in db_summary_client.table_info_json():
for table_summary in db_summary_client.table_summaries():
from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=len(table_summary), chunk_overlap=100
chunk_size=len(table_summary), chunk_overlap=0
)
embedding = StringEmbedding(
file_path=table_summary,
@@ -195,23 +108,8 @@ class DBSummaryClient:
text_splitter=text_splitter,
)
docs.extend(embedding.read_batch())
embedding.index_to_store(docs)
if len(docs) > 0:
embedding.index_to_store(docs)
else:
logger.info(f"Vector store name {vector_store_name} exist")
logger.info("init db profile success...")
def _get_llm_response(query, db_input, dbsummary):
chat_param = {
"temperature": 0.7,
"max_new_tokens": 512,
"chat_session_id": uuid.uuid1(),
"user_input": query,
"db_select": db_input,
"db_summary": dbsummary,
}
chat: BaseChat = chat_factory.get_implementation(
ChatScene.InnerChatDBSummary.value, **chat_param
)
res = chat._blocking_nostream_call()
return json.loads(res)["table"]
logger.info("initialize db summary profile success...")

View File

@@ -1,22 +1,22 @@
import json
from pilot.configs.config import Config
from pilot.summary.db_summary import DBSummary, TableSummary, FieldSummary, IndexSummary
from pilot.summary.db_summary import DBSummary
CFG = Config()
class RdbmsSummary(DBSummary):
"""Get mysql summary template."""
"""Get rdbms db table summary template.
summary example:
table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment is {table_comment})
"""
def __init__(self, name, type):
self.name = name
self.type = type
self.summary = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
self.summary_template = "{table_name}({columns})"
self.tables = {}
self.tables_info = []
self.vector_tables_info = []
# self.tables_summary = {}
self.db = CFG.LOCAL_DB_MANAGE.get_connect(name)
@@ -27,154 +27,41 @@ class RdbmsSummary(DBSummary):
collation=self.db.get_collation(),
)
tables = self.db.get_table_names()
self.table_comments = self.db.get_table_comments(name)
comment_map = {}
for table_comment in self.table_comments:
self.tables_info.append(
"table name:{table_name},table description:{table_comment}".format(
table_name=table_comment[0], table_comment=table_comment[1]
)
)
comment_map[table_comment[0]] = table_comment[1]
self.table_info_summaries = [
self.get_table_summary(table_name) for table_name in tables
]
vector_table = json.dumps(
{"table_name": table_comment[0], "table_description": table_comment[1]}
)
self.vector_tables_info.append(
vector_table.encode("utf-8").decode("unicode_escape")
)
self.table_columns_info = []
self.table_columns_json = []
def get_table_summary(self, table_name):
"""Get table summary for table.
example:
table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment: {table_comment})
"""
columns = []
for column in self.db._inspector.get_columns(table_name):
if column.get("comment"):
columns.append((f"{column['name']} ({column.get('comment')})"))
else:
columns.append(f"{column['name']}")
for table_name in tables:
table_summary = RdbmsTableSummary(self.db, name, table_name, comment_map)
# self.tables[table_name] = table_summary.get_summary()
self.tables[table_name] = table_summary.get_columns()
self.table_columns_info.append(table_summary.get_columns())
# self.table_columns_json.append(table_summary.get_summary_json())
table_profile = (
"table name:{table_name},table description:{table_comment}".format(
table_name=table_name,
table_comment=self.db.get_show_create_table(table_name),
)
)
self.table_columns_json.append(table_profile)
# self.tables_info.append(table_summary.get_summary())
def get_summary(self):
if CFG.SUMMARY_CONFIG == "FAST":
return self.vector_tables_info
else:
return self.summary.format(
name=self.name, type=self.type, table_info=";".join(self.tables_info)
)
def get_db_summary(self):
return self.summary.format(
name=self.name,
type=self.type,
tables=";".join(self.vector_tables_info),
qps=1000,
tps=1000,
column_str = ", ".join(columns)
index_keys = []
for index_key in self.db._inspector.get_indexes(table_name):
key_str = ", ".join(index_key["column_names"])
index_keys.append(f"{index_key['name']}(`{key_str}`) ")
table_str = self.summary_template.format(
table_name=table_name, columns=column_str
)
if len(index_keys) > 0:
index_key_str = ", ".join(index_keys)
table_str += f", and index keys: {index_key_str}"
try:
comment = self.db._inspector.get_table_comment(table_name)
except Exception:
comment = dict(text=None)
if comment.get("text"):
table_str += f", and table comment: {comment.get('text')}"
return table_str
def get_table_summary(self):
return self.tables
def get_table_comments(self):
return self.table_comments
def table_info_json(self):
return self.table_columns_json
class RdbmsTableSummary(TableSummary):
"""Get mysql table summary template."""
def __init__(self, instance, dbname, name, comment_map):
self.name = name
self.dbname = dbname
self.summary = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}"""
self.json_summary_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}"""
self.fields = []
self.fields_info = []
self.indexes = []
self.indexes_info = []
self.db = instance
fields = self.db.get_fields(name)
indexes = self.db.get_indexes(name)
field_names = []
for field in fields:
field_summary = RdbmsFieldsSummary(field)
self.fields.append(field_summary)
self.fields_info.append(field_summary.get_summary())
field_names.append(field[0])
self.column_summary = """{name}({columns_info})""".format(
name=name, columns_info=",".join(field_names)
)
for index in indexes:
index_summary = RdbmsIndexSummary(index)
self.indexes.append(index_summary)
self.indexes_info.append(index_summary.get_summary())
self.json_summary = self.json_summary_template.format(
name=name,
comment=comment_map[name],
fields=self.fields_info,
indexes=self.indexes_info,
size_in_bytes=1000,
rows=1000,
)
def get_summary(self):
return self.summary.format(
name=self.name,
dbname=self.dbname,
fields=";".join(self.fields_info),
indexes=";".join(self.indexes_info),
)
def get_columns(self):
return self.column_summary
def get_summary_json(self):
return self.json_summary
class RdbmsFieldsSummary(FieldSummary):
"""Get mysql field summary template."""
def __init__(self, field):
self.name = field[0]
# self.summary = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """
# self.summary = """{"name": {name}, "type": {data_type}, "is_primary_key": {is_nullable}, "comment":{comment}, "default":{default_value}}"""
self.data_type = field[1]
self.default_value = field[2]
self.is_nullable = field[3]
self.comment = field[4]
def get_summary(self):
return '{{"name": "{name}", "type": "{data_type}", "is_primary_key": "{is_nullable}", "comment": "{comment}", "default": "{default_value}"}}'.format(
name=self.name,
data_type=self.data_type,
is_nullable=self.is_nullable,
default_value=self.default_value,
comment=self.comment,
)
class RdbmsIndexSummary(IndexSummary):
"""Get mysql index summary template."""
def __init__(self, index):
self.name = index[0]
# self.summary = """index name:{name}, index bind columns:{bind_fields}"""
self.summary_template = '{{"name": "{name}", "columns": {bind_fields}}}'
self.bind_fields = index[1]
def get_summary(self):
return self.summary_template.format(
name=self.name, bind_fields=self.bind_fields
)
def table_summaries(self):
"""Get table summaries."""
return self.table_info_summaries