DB-GPT/pilot/summary/db_summary_client.py

import json
import uuid

from langchain.embeddings import HuggingFaceEmbeddings, logger

from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.scene.base import ChatScene
from pilot.scene.base_chat import BaseChat
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.string_embedding import StringEmbedding
from pilot.summary.mysql_db_summary import MysqlSummary
from pilot.scene.chat_factory import ChatFactory
from pilot.common.schema import DBType

CFG = Config()
chat_factory = ChatFactory()


class DBSummaryClient:
    """DB summary client.

    Provides db_summary_embedding (put the db profile and table profile summaries
    into the vector store) and get_similar_tables (get table info related to the
    user query).
    """

    def __init__(self):
        pass

    def db_summary_embedding(self, dbname, db_type):
        """Put the db profile and table profile summaries into the vector store."""
        if DBType.Mysql.value() == db_type:
            db_summary_client = MysqlSummary(dbname)
        else:
            raise ValueError("Unsupported summary DbType: " + db_type)
        embeddings = HuggingFaceEmbeddings(
            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
        )
        vector_store_config = {
            "vector_store_name": dbname + "_summary",
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
            "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
            "embeddings": embeddings,
        }
        embedding = StringEmbedding(
            file_path=db_summary_client.get_summery(),
            vector_store_config=vector_store_config,
        )
        self.init_db_profile(db_summary_client, dbname, embeddings)
        if not embedding.vector_name_exist():
            if CFG.SUMMARY_CONFIG == "FAST":
                # FAST mode: embed each per-table summary string directly.
                for vector_table_info in db_summary_client.get_summery():
                    embedding = StringEmbedding(
                        vector_table_info,
                        vector_store_config,
                    )
                    embedding.source_embedding()
            else:
                # Otherwise embed the whole db summary as a single document.
                embedding = StringEmbedding(
                    file_path=db_summary_client.get_summery(),
                    vector_store_config=vector_store_config,
                )
                embedding.source_embedding()
            # One vector store per table ("<dbname>_<table>_ts") for table summaries.
            for (
                table_name,
                table_summary,
            ) in db_summary_client.get_table_summary().items():
                table_vector_store_config = {
                    "vector_store_name": dbname + "_" + table_name + "_ts",
                    "vector_store_type": CFG.VECTOR_STORE_TYPE,
                    "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
                    "embeddings": embeddings,
                }
                embedding = StringEmbedding(
                    table_summary,
                    table_vector_store_config,
                )
                embedding.source_embedding()
        logger.info("db summary embedding success")

    def get_db_summary(self, dbname, query, topk):
        """Get the topk db profile documents most similar to the query."""
        vector_store_config = {
            "vector_store_name": dbname + "_profile",
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
            "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
        }
        knowledge_embedding_client = EmbeddingEngine(
            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
        )
        table_docs = knowledge_embedding_client.similar_search(query, topk)
        ans = [d.page_content for d in table_docs]
        return ans

    def get_similar_tables(self, dbname, query, topk):
        """Get table info related to the user query."""
        vector_store_config = {
            "vector_store_name": dbname + "_summary",
            "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
        }
        knowledge_embedding_client = EmbeddingEngine(
            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
        )
        if CFG.SUMMARY_CONFIG == "FAST":
            # FAST mode: the summary store holds per-table JSON, so table names
            # can be read straight from the search results.
            table_docs = knowledge_embedding_client.similar_search(query, topk)
            related_tables = [
                json.loads(table_doc.page_content)["table_name"]
                for table_doc in table_docs
            ]
        else:
            # Otherwise ask the LLM which tables relate to the query.
            table_docs = knowledge_embedding_client.similar_search(query, 1)
            # prompt = KnownLedgeBaseQA.build_db_summary_prompt(
            #     query, table_docs[0].page_content
            # )
            related_tables = _get_llm_response(
                query, dbname, table_docs[0].page_content
            )
        related_table_summaries = []
        for table in related_tables:
            vector_store_config = {
                "vector_store_name": dbname + "_" + table + "_ts",
                "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
                "vector_store_type": CFG.VECTOR_STORE_TYPE,
            }
            knowledge_embedding_client = EmbeddingEngine(
                model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                vector_store_config=vector_store_config,
            )
            table_summery = knowledge_embedding_client.similar_search(query, 1)
            related_table_summaries.append(table_summery[0].page_content)
        return related_table_summaries

    def init_db_summary(self):
        """Build summary embeddings for every database registered in the local db manager."""
        db_mange = CFG.LOCAL_DB_MANAGE
        dbs = db_mange.get_db_list()
        for item in dbs:
            self.db_summary_embedding(item["db_name"], item["db_type"])

    def init_db_profile(self, db_summary_client, dbname, embeddings):
        """Build the db profile vector store ("<dbname>_profile") if it does not exist yet."""
        profile_store_config = {
            "vector_store_name": dbname + "_profile",
            "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
            "embeddings": embeddings,
        }
        embedding = StringEmbedding(
            file_path=db_summary_client.get_db_summery(),
            vector_store_config=profile_store_config,
        )
        if not embedding.vector_name_exist():
            docs = []
            docs.extend(embedding.read_batch())
            for table_summary in db_summary_client.table_info_json():
                embedding = StringEmbedding(
                    table_summary,
                    profile_store_config,
                )
                docs.extend(embedding.read_batch())
            embedding.index_to_store(docs)
        logger.info("init db profile success...")


def _get_llm_response(query, db_input, dbsummary):
    """Ask the InnerChatDBSummary scene which tables relate to the user query."""
    chat_param = {
        "temperature": 0.7,
        "max_new_tokens": 512,
        "chat_session_id": uuid.uuid1(),
        "user_input": query,
        "db_select": db_input,
        "db_summary": dbsummary,
    }
    chat: BaseChat = chat_factory.get_implementation(
        ChatScene.InnerChatDBSummary.value, **chat_param
    )
    res = chat.nostream_call()
    return json.loads(res)["table"]
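

# Usage sketch (illustrative only, not part of the original module): shows how the
# client is typically driven. It assumes Config is set up with LOCAL_DB_MANAGE and
# that a MySQL database is registered under the hypothetical name "test_db".
if __name__ == "__main__":
    client = DBSummaryClient()
    # Build the "_summary", "_profile" and per-table "_ts" vector stores for every
    # registered database (skipped when the stores already exist).
    client.init_db_summary()
    # Fetch the table summaries most relevant to a natural-language question.
    tables = client.get_similar_tables("test_db", "total orders per user", 3)
    # Fetch the db profile documents most relevant to the same question.
    profile = client.get_db_summary("test_db", "total orders per user", 5)
    print(tables)
    print(profile)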