refactor(ChatData):update rdbms db summary (#885)

This commit is contained in:
Aries-ckt
2023-12-04 22:16:28 +08:00
committed by GitHub
parent b12a858d53
commit a2f087c429
3 changed files with 48 additions and 330 deletions

View File

@@ -60,73 +60,6 @@ class ConnectManager:
# self.storage = DuckdbConnectConfig() # self.storage = DuckdbConnectConfig()
self.storage = ConnectConfigDao() self.storage = ConnectConfigDao()
self.db_summary_client = DBSummaryClient(system_app) self.db_summary_client = DBSummaryClient(system_app)
# self.__load_config_db()
# def __load_config_db(self):
# if CFG.LOCAL_DB_HOST:
# # default mysql
# if CFG.LOCAL_DB_NAME:
# self.storage.add_url_db(
# CFG.LOCAL_DB_NAME,
# DBType.Mysql.value(),
# CFG.LOCAL_DB_HOST,
# CFG.LOCAL_DB_PORT,
# CFG.LOCAL_DB_USER,
# CFG.LOCAL_DB_PASSWORD,
# "",
# )
# else:
# # get all default mysql database
# default_mysql = Database.from_uri(
# "mysql+pymysql://"
# + CFG.LOCAL_DB_USER
# + ":"
# + CFG.LOCAL_DB_PASSWORD
# + "@"
# + CFG.LOCAL_DB_HOST
# + ":"
# + str(CFG.LOCAL_DB_PORT),
# engine_args={
# "pool_size": CFG.LOCAL_DB_POOL_SIZE,
# "pool_recycle": 3600,
# "echo": True,
# },
# )
# dbs = default_mysql.get_database_list()
# for name in dbs:
# self.storage.add_url_db(
# name,
# DBType.Mysql.value(),
# CFG.LOCAL_DB_HOST,
# CFG.LOCAL_DB_PORT,
# CFG.LOCAL_DB_USER,
# CFG.LOCAL_DB_PASSWORD,
# "",
# )
# db_type = DBType.of_db_type(CFG.LOCAL_DB_TYPE)
# if db_type.is_file_db():
# db_name = CFG.LOCAL_DB_NAME
# db_type = CFG.LOCAL_DB_TYPE
# db_path = CFG.LOCAL_DB_PATH
# if not db_type:
# # Default file database type
# db_type = DBType.DuckDb.value()
# if not db_name:
# db_type, db_name = self._parse_file_db_info(db_type, db_path)
# if db_name:
# print(
# f"Add file db, db_name: {db_name}, db_type: {db_type}, db_path: {db_path}"
# )
# self.storage.add_file_db(db_name, db_type, db_path)
# def _parse_file_db_info(self, db_type: str, db_path: str):
# if db_type is None or db_type == DBType.DuckDb.value():
# # file db is duckdb
# db_name = self.storage.get_file_db_name(db_path)
# db_type = DBType.DuckDb.value()
# else:
# db_name = DBType.parse_file_db_name_from_path(db_type, db_path)
# return db_type, db_name
def get_connect(self, db_name): def get_connect(self, db_name):
db_config = self.storage.get_db_config(db_name) db_config = self.storage.get_db_config(db_name)

View File

@@ -2,16 +2,12 @@ import json
import uuid import uuid
import logging import logging
from pilot.common.schema import DBType
from pilot.component import SystemApp from pilot.component import SystemApp
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import ( from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
EMBEDDING_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG,
) )
from pilot.scene.base import ChatScene
from pilot.scene.base_chat import BaseChat
from pilot.scene.chat_factory import ChatFactory from pilot.scene.chat_factory import ChatFactory
from pilot.summary.rdbms_db_summary import RdbmsSummary from pilot.summary.rdbms_db_summary import RdbmsSummary
@@ -33,7 +29,6 @@ class DBSummaryClient:
def db_summary_embedding(self, dbname, db_type): def db_summary_embedding(self, dbname, db_type):
"""put db profile and table profile summary into vector store""" """put db profile and table profile summary into vector store"""
from pilot.embedding_engine.string_embedding import StringEmbedding
from pilot.embedding_engine.embedding_factory import EmbeddingFactory from pilot.embedding_engine.embedding_factory import EmbeddingFactory
db_summary_client = RdbmsSummary(dbname, db_type) db_summary_client = RdbmsSummary(dbname, db_type)
@@ -43,48 +38,12 @@ class DBSummaryClient:
embeddings = embedding_factory.create( embeddings = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL] model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
) )
vector_store_config = {
"vector_store_name": dbname + "_summary",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"embeddings": embeddings,
}
embedding = StringEmbedding(
file_path=db_summary_client.get_summary(),
vector_store_config=vector_store_config,
)
self.init_db_profile(db_summary_client, dbname, embeddings) self.init_db_profile(db_summary_client, dbname, embeddings)
if not embedding.vector_name_exist():
if CFG.SUMMARY_CONFIG == "FAST":
for vector_table_info in db_summary_client.get_summary():
embedding = StringEmbedding(
vector_table_info,
vector_store_config,
)
embedding.source_embedding()
else:
embedding = StringEmbedding(
file_path=db_summary_client.get_summary(),
vector_store_config=vector_store_config,
)
embedding.source_embedding()
for (
table_name,
table_summary,
) in db_summary_client.get_table_summary().items():
table_vector_store_config = {
"vector_store_name": dbname + "_" + table_name + "_ts",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"embeddings": embeddings,
}
embedding = StringEmbedding(
table_summary,
table_vector_store_config,
)
embedding.source_embedding()
logger.info("db summary embedding success") logger.info("db summary embedding success")
def get_db_summary(self, dbname, query, topk): def get_db_summary(self, dbname, query, topk):
"""get user query related tables info"""
from pilot.embedding_engine.embedding_engine import EmbeddingEngine from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.embedding_factory import EmbeddingFactory from pilot.embedding_engine.embedding_factory import EmbeddingFactory
@@ -104,53 +63,8 @@ class DBSummaryClient:
ans = [d.page_content for d in table_docs] ans = [d.page_content for d in table_docs]
return ans return ans
def get_similar_tables(self, dbname, query, topk):
"""get user query related tables info"""
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
vector_store_config = {
"vector_store_name": dbname + "_summary",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
}
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
embedding_factory=embedding_factory,
)
if CFG.SUMMARY_CONFIG == "FAST":
table_docs = knowledge_embedding_client.similar_search(query, topk)
related_tables = [
json.loads(table_doc.page_content)["table_name"]
for table_doc in table_docs
]
else:
table_docs = knowledge_embedding_client.similar_search(query, 1)
# prompt = KnownLedgeBaseQA.build_db_summary_prompt(
# query, table_docs[0].page_content
# )
related_tables = _get_llm_response(
query, dbname, table_docs[0].page_content
)
related_table_summaries = []
for table in related_tables:
vector_store_config = {
"vector_store_name": dbname + "_" + table + "_ts",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
}
knowledge_embedding_client = EmbeddingEngine(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
embedding_factory=embedding_factory,
)
table_summary = knowledge_embedding_client.similar_search(query, 1)
related_table_summaries.append(table_summary[0].page_content)
return related_table_summaries
def init_db_summary(self): def init_db_summary(self):
"""init db summary"""
db_mange = CFG.LOCAL_DB_MANAGE db_mange = CFG.LOCAL_DB_MANAGE
dbs = db_mange.get_db_list() dbs = db_mange.get_db_list()
for item in dbs: for item in dbs:
@@ -177,17 +91,16 @@ class DBSummaryClient:
"embeddings": embeddings, "embeddings": embeddings,
} }
embedding = StringEmbedding( embedding = StringEmbedding(
file_path=db_summary_client.get_db_summary(), file_path=None,
vector_store_config=profile_store_config, vector_store_config=profile_store_config,
) )
if not embedding.vector_name_exist(): if not embedding.vector_name_exist():
docs = [] docs = []
docs.extend(embedding.read_batch()) for table_summary in db_summary_client.table_summaries():
for table_summary in db_summary_client.table_info_json():
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
chunk_size=len(table_summary), chunk_overlap=100 chunk_size=len(table_summary), chunk_overlap=0
) )
embedding = StringEmbedding( embedding = StringEmbedding(
file_path=table_summary, file_path=table_summary,
@@ -195,23 +108,8 @@ class DBSummaryClient:
text_splitter=text_splitter, text_splitter=text_splitter,
) )
docs.extend(embedding.read_batch()) docs.extend(embedding.read_batch())
if len(docs) > 0:
embedding.index_to_store(docs) embedding.index_to_store(docs)
else: else:
logger.info(f"Vector store name {vector_store_name} exist") logger.info(f"Vector store name {vector_store_name} exist")
logger.info("init db profile success...") logger.info("initialize db summary profile success...")
def _get_llm_response(query, db_input, dbsummary):
chat_param = {
"temperature": 0.7,
"max_new_tokens": 512,
"chat_session_id": uuid.uuid1(),
"user_input": query,
"db_select": db_input,
"db_summary": dbsummary,
}
chat: BaseChat = chat_factory.get_implementation(
ChatScene.InnerChatDBSummary.value, **chat_param
)
res = chat._blocking_nostream_call()
return json.loads(res)["table"]

View File

@@ -1,22 +1,22 @@
import json
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.summary.db_summary import DBSummary, TableSummary, FieldSummary, IndexSummary from pilot.summary.db_summary import DBSummary
CFG = Config() CFG = Config()
class RdbmsSummary(DBSummary): class RdbmsSummary(DBSummary):
"""Get mysql summary template.""" """Get rdbms db table summary template.
summary example:
table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment is {table_comment})
"""
def __init__(self, name, type): def __init__(self, name, type):
self.name = name self.name = name
self.type = type self.type = type
self.summary = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}""" self.summary_template = "{table_name}({columns})"
self.tables = {} self.tables = {}
self.tables_info = [] self.tables_info = []
self.vector_tables_info = [] self.vector_tables_info = []
# self.tables_summary = {}
self.db = CFG.LOCAL_DB_MANAGE.get_connect(name) self.db = CFG.LOCAL_DB_MANAGE.get_connect(name)
@@ -27,154 +27,41 @@ class RdbmsSummary(DBSummary):
collation=self.db.get_collation(), collation=self.db.get_collation(),
) )
tables = self.db.get_table_names() tables = self.db.get_table_names()
self.table_comments = self.db.get_table_comments(name) self.table_info_summaries = [
comment_map = {} self.get_table_summary(table_name) for table_name in tables
for table_comment in self.table_comments: ]
self.tables_info.append(
"table name:{table_name},table description:{table_comment}".format(
table_name=table_comment[0], table_comment=table_comment[1]
)
)
comment_map[table_comment[0]] = table_comment[1]
vector_table = json.dumps( def get_table_summary(self, table_name):
{"table_name": table_comment[0], "table_description": table_comment[1]} """Get table summary for table.
) example:
self.vector_tables_info.append( table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment: {table_comment})
vector_table.encode("utf-8").decode("unicode_escape") """
) columns = []
self.table_columns_info = [] for column in self.db._inspector.get_columns(table_name):
self.table_columns_json = [] if column.get("comment"):
columns.append((f"{column['name']} ({column.get('comment')})"))
for table_name in tables:
table_summary = RdbmsTableSummary(self.db, name, table_name, comment_map)
# self.tables[table_name] = table_summary.get_summary()
self.tables[table_name] = table_summary.get_columns()
self.table_columns_info.append(table_summary.get_columns())
# self.table_columns_json.append(table_summary.get_summary_json())
table_profile = (
"table name:{table_name},table description:{table_comment}".format(
table_name=table_name,
table_comment=self.db.get_show_create_table(table_name),
)
)
self.table_columns_json.append(table_profile)
# self.tables_info.append(table_summary.get_summary())
def get_summary(self):
if CFG.SUMMARY_CONFIG == "FAST":
return self.vector_tables_info
else: else:
return self.summary.format( columns.append(f"{column['name']}")
name=self.name, type=self.type, table_info=";".join(self.tables_info)
column_str = ", ".join(columns)
index_keys = []
for index_key in self.db._inspector.get_indexes(table_name):
key_str = ", ".join(index_key["column_names"])
index_keys.append(f"{index_key['name']}(`{key_str}`) ")
table_str = self.summary_template.format(
table_name=table_name, columns=column_str
) )
if len(index_keys) > 0:
index_key_str = ", ".join(index_keys)
table_str += f", and index keys: {index_key_str}"
try:
comment = self.db._inspector.get_table_comment(table_name)
except Exception:
comment = dict(text=None)
if comment.get("text"):
table_str += f", and table comment: {comment.get('text')}"
return table_str
def get_db_summary(self): def table_summaries(self):
return self.summary.format( """Get table summaries."""
name=self.name, return self.table_info_summaries
type=self.type,
tables=";".join(self.vector_tables_info),
qps=1000,
tps=1000,
)
def get_table_summary(self):
return self.tables
def get_table_comments(self):
return self.table_comments
def table_info_json(self):
return self.table_columns_json
class RdbmsTableSummary(TableSummary):
"""Get mysql table summary template."""
def __init__(self, instance, dbname, name, comment_map):
self.name = name
self.dbname = dbname
self.summary = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}"""
self.json_summary_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}"""
self.fields = []
self.fields_info = []
self.indexes = []
self.indexes_info = []
self.db = instance
fields = self.db.get_fields(name)
indexes = self.db.get_indexes(name)
field_names = []
for field in fields:
field_summary = RdbmsFieldsSummary(field)
self.fields.append(field_summary)
self.fields_info.append(field_summary.get_summary())
field_names.append(field[0])
self.column_summary = """{name}({columns_info})""".format(
name=name, columns_info=",".join(field_names)
)
for index in indexes:
index_summary = RdbmsIndexSummary(index)
self.indexes.append(index_summary)
self.indexes_info.append(index_summary.get_summary())
self.json_summary = self.json_summary_template.format(
name=name,
comment=comment_map[name],
fields=self.fields_info,
indexes=self.indexes_info,
size_in_bytes=1000,
rows=1000,
)
def get_summary(self):
return self.summary.format(
name=self.name,
dbname=self.dbname,
fields=";".join(self.fields_info),
indexes=";".join(self.indexes_info),
)
def get_columns(self):
return self.column_summary
def get_summary_json(self):
return self.json_summary
class RdbmsFieldsSummary(FieldSummary):
"""Get mysql field summary template."""
def __init__(self, field):
self.name = field[0]
# self.summary = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """
# self.summary = """{"name": {name}, "type": {data_type}, "is_primary_key": {is_nullable}, "comment":{comment}, "default":{default_value}}"""
self.data_type = field[1]
self.default_value = field[2]
self.is_nullable = field[3]
self.comment = field[4]
def get_summary(self):
return '{{"name": "{name}", "type": "{data_type}", "is_primary_key": "{is_nullable}", "comment": "{comment}", "default": "{default_value}"}}'.format(
name=self.name,
data_type=self.data_type,
is_nullable=self.is_nullable,
default_value=self.default_value,
comment=self.comment,
)
class RdbmsIndexSummary(IndexSummary):
"""Get mysql index summary template."""
def __init__(self, index):
self.name = index[0]
# self.summary = """index name:{name}, index bind columns:{bind_fields}"""
self.summary_template = '{{"name": "{name}", "columns": {bind_fields}}}'
self.bind_fields = index[1]
def get_summary(self):
return self.summary_template.format(
name=self.name, bind_fields=self.bind_fields
)