mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-28 21:12:13 +00:00
130 lines
4.8 KiB
Python
130 lines
4.8 KiB
Python
import json
|
|
import uuid
|
|
|
|
from langchain.embeddings import HuggingFaceEmbeddings, logger
|
|
|
|
from pilot.configs.config import Config
|
|
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
|
from pilot.scene.base import ChatScene
|
|
from pilot.scene.base_chat import BaseChat
|
|
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
|
|
from pilot.source_embedding.string_embedding import StringEmbedding
|
|
from pilot.summary.mysql_db_summary import MysqlSummary
|
|
from pilot.scene.chat_factory import ChatFactory
|
|
|
|
# Module-level configuration; the code below reads LOCAL_DB_HOST,
# LOCAL_DB_PORT, EMBEDDING_MODEL, SUMMARY_CONFIG and local_db from it.
CFG = Config()

# Factory that _get_llm_response uses to look up a chat scene implementation.
chat_factory = ChatFactory()
|
|
|
|
|
|
class DBSummaryClient:
    """Client for embedding database summaries and finding related tables.

    Provides ``db_summary_embedding`` (put the db profile and table profile
    summaries into the vector store) and ``get_similar_tables`` (get the
    table info related to a user query).
    """

    def __init__(self):
        """Create the client; no instance state is set up."""

    def db_summary_embedding(self, dbname):
        """Put the db profile and table profile summaries into the vector store.

        The db profile is stored under ``<dbname>_profile`` and every table
        summary under its own ``<dbname>_<table_name>_ts`` store. Nothing is
        embedded when ``vector_name_exist()`` is already true for the profile
        store, or when no local database host/port is configured.

        Args:
            dbname: Name of the database to summarise.
        """
        if CFG.LOCAL_DB_HOST is None or CFG.LOCAL_DB_PORT is None:
            # MysqlSummary is the only summary source used here; without it
            # the code below would fail on an unbound ``db_summary_client``.
            logger.warning(
                "skip db summary embedding for %s: local db is not configured",
                dbname,
            )
            return

        db_summary_client = MysqlSummary(dbname)
        embeddings = HuggingFaceEmbeddings(
            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
        )
        vector_store_config = {
            "vector_store_name": dbname + "_profile",
            "embeddings": embeddings,
        }
        profile_embedding = StringEmbedding(
            file_path=db_summary_client.get_summery(),
            vector_store_config=vector_store_config,
        )
        # NOTE(review): presumably true once the profile store has been
        # built, which makes this method a no-op on later runs - confirm.
        if profile_embedding.vector_name_exist():
            return

        if CFG.SUMMARY_CONFIG == "FAST":
            # FAST mode embeds each item of the summary separately, all into
            # the same profile store.
            for vector_table_info in db_summary_client.get_summery():
                StringEmbedding(
                    vector_table_info,
                    vector_store_config,
                ).source_embedding()
        else:
            StringEmbedding(
                file_path=db_summary_client.get_summery(),
                vector_store_config=vector_store_config,
            ).source_embedding()

        # One dedicated vector store per table summary.
        table_summaries = db_summary_client.get_table_summary()
        for table_name, table_summary in table_summaries.items():
            table_vector_store_config = {
                "vector_store_name": dbname + "_" + table_name + "_ts",
                "embeddings": embeddings,
            }
            StringEmbedding(
                table_summary,
                table_vector_store_config,
            ).source_embedding()

        logger.info("db summary embedding success")

    def get_similar_tables(self, dbname, query, topk):
        """Get the table summaries related to a user query.

        In ``FAST`` mode the ``topk`` most similar documents of the
        ``<dbname>_profile`` store are read and the ``table_name`` field of
        their JSON content names the related tables. Otherwise only the best
        profile document is fetched and ``_get_llm_response`` picks the
        tables from it (``topk`` is not used in that mode).

        Args:
            dbname: Name of the database whose stores are searched.
            query: The user query.
            topk: Number of profile documents to fetch in ``FAST`` mode.

        Returns:
            A list with the ``page_content`` of the best match from each
            related table's ``<dbname>_<table>_ts`` store. Tables whose store
            returns no match are skipped; the list is empty when the profile
            search finds nothing.
        """
        profile_client = KnowledgeEmbedding(
            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config={"vector_store_name": dbname + "_profile"},
        )
        if CFG.SUMMARY_CONFIG == "FAST":
            table_docs = profile_client.similar_search(query, topk)
            related_tables = [
                json.loads(table_doc.page_content)["table_name"]
                for table_doc in table_docs
            ]
        else:
            table_docs = profile_client.similar_search(query, 1)
            if not table_docs:
                # No profile document was found, so there is nothing to hand
                # to the LLM; indexing [0] here would raise IndexError.
                return []
            related_tables = _get_llm_response(
                query, dbname, table_docs[0].page_content
            )

        related_table_summaries = []
        for table in related_tables:
            table_client = KnowledgeEmbedding(
                file_path="",
                model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                vector_store_config={
                    "vector_store_name": dbname + "_" + table + "_ts",
                },
            )
            table_summery = table_client.similar_search(query, 1)
            # Guard against an empty result instead of failing on [0].
            if table_summery:
                related_table_summaries.append(table_summery[0].page_content)
        return related_table_summaries

    def init_db_summary(self):
        """Run ``db_summary_embedding`` for every database of ``CFG.local_db``."""
        db = CFG.local_db
        for dbname in db.get_database_list():
            self.db_summary_embedding(dbname)
|
|
|
|
|
|
def _get_llm_response(query, db_input, dbsummary):
    """Ask the InnerChatDBSummary chat scene which tables relate to a query.

    Args:
        query: The user query, passed to the scene as ``user_input``.
        db_input: The selected database, passed as ``db_select``.
        dbsummary: The database summary text, passed as ``db_summary``.

    Returns:
        The value stored under the ``"table"`` key of the JSON document
        returned by the scene's non-streaming call.
    """
    scene_name = ChatScene.InnerChatDBSummary.value
    scene_kwargs = {
        "temperature": 0.7,
        "max_new_tokens": 512,
        # A fresh session id is generated for every request.
        "chat_session_id": uuid.uuid1(),
        "user_input": query,
        "db_select": db_input,
        "db_summary": dbsummary,
    }
    summary_chat: BaseChat = chat_factory.get_implementation(
        scene_name, **scene_kwargs
    )
    parsed_reply = json.loads(summary_chat.nostream_call())
    return parsed_reply["table"]
|