mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-02 16:50:24 +00:00
143 lines
5.1 KiB
Python
143 lines
5.1 KiB
Python
"""DBSummaryClient class."""
|
|
|
|
import logging
|
|
import traceback
|
|
from typing import List
|
|
|
|
from dbgpt._private.config import Config
|
|
from dbgpt.component import SystemApp
|
|
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
|
|
from dbgpt.rag.summary.gdbms_db_summary import GdbmsSummary
|
|
from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
# Shared ``Config`` instance; the code below reads its EMBEDDING_MODEL,
# VECTOR_STORE_TYPE and local_db_manager attributes.
CFG = Config()
|
|
|
|
|
|
class DBSummaryClient:
    """The client for DBSummary.

    DB Summary client. It provides ``db_summary_embedding`` (put db profile and
    table profile summary into vector store) and ``get_db_summary`` (get user
    query related tables info).

    Args:
        system_app (SystemApp): Main System Application class that manages the
            lifecycle and registration of components.
    """

    def __init__(self, system_app: SystemApp):
        """Create a new DBSummaryClient.

        Looks up the ``"embedding_factory"`` component on ``system_app`` and uses
        it to create the embeddings for the model selected by
        ``CFG.EMBEDDING_MODEL``; the result is stored on ``self.embeddings``.
        """
        self.system_app = system_app
        from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory

        embedding_factory: EmbeddingFactory = self.system_app.get_component(
            "embedding_factory", component_type=EmbeddingFactory
        )
        self.embeddings = embedding_factory.create(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
        )

    @staticmethod
    def _profile_store_name(dbname):
        """Return the vector store name used for the profile of ``dbname``."""
        return dbname + "_profile"

    def _create_vector_connector(self, vector_store_name):
        """Build the vector store connector for ``vector_store_name``.

        Single place for the connector construction shared by
        ``get_db_summary``, ``init_db_profile`` and ``delete_db_profile``. The
        connector uses ``CFG.VECTOR_STORE_TYPE`` and ``self.embeddings``.
        """
        from dbgpt.serve.rag.connector import VectorStoreConnector
        from dbgpt.storage.vector_store.base import VectorStoreConfig

        vector_store_config = VectorStoreConfig(name=vector_store_name)
        return VectorStoreConnector.from_default(
            CFG.VECTOR_STORE_TYPE,
            embedding_fn=self.embeddings,
            vector_store_config=vector_store_config,
        )

    def db_summary_embedding(self, dbname, db_type):
        """Put db profile and table profile summary into vector store.

        Args:
            dbname: Name of the database; also used to derive the vector store
                name.
            db_type: Database type, used to pick the summary implementation
                (see ``create_summary_client``).
        """
        db_summary_client = self.create_summary_client(dbname, db_type)
        self.init_db_profile(db_summary_client, dbname)
        logger.info("db summary embedding success")

    def get_db_summary(self, dbname, query, topk) -> List[str]:
        """Get user query related tables info.

        Args:
            dbname: Name of the database whose profile store is searched.
            query: The query handed to the schema retriever.
            topk: Passed to the retriever as ``top_k``.

        Returns:
            List[str]: The ``content`` of every document the retriever returned.
        """
        from dbgpt.rag.retriever.db_schema import DBSchemaRetriever

        vector_connector = self._create_vector_connector(
            self._profile_store_name(dbname)
        )
        retriever = DBSchemaRetriever(
            top_k=topk, index_store=vector_connector.index_client
        )
        return [doc.content for doc in retriever.retrieve(query)]

    def init_db_summary(self):
        """Initialize db summary profile.

        Runs ``db_summary_embedding`` for every entry returned by
        ``CFG.local_db_manager.get_db_list()``. A failure for one database is
        logged as a warning (with traceback) and does not stop the others.
        """
        db_manager = CFG.local_db_manager
        for item in db_manager.get_db_list():
            db_name = item["db_name"]
            db_type = item["db_type"]
            try:
                self.db_summary_embedding(db_name, db_type)
            except Exception as e:
                # Best effort: keep going with the remaining databases.
                # ``logger.warning`` replaces the deprecated ``logger.warn``.
                logger.warning(
                    "%s, %s summary error!%s, detail: %s",
                    db_name,
                    db_type,
                    e,
                    traceback.format_exc(),
                )

    def init_db_profile(self, db_summary_client, dbname):
        """Initialize db summary profile.

        If the vector store ``<dbname>_profile`` does not exist yet, the schema is
        loaded from ``db_summary_client.db`` and persisted when it yields at
        least one chunk. An existing store is left untouched.

        Args:
            db_summary_client: Summary object as returned by
                ``create_summary_client``; its ``db`` attribute is used as the
                connector.
            dbname(str): dbname
        """
        vector_store_name = self._profile_store_name(dbname)
        vector_connector = self._create_vector_connector(vector_store_name)
        if not vector_connector.vector_name_exists():
            from dbgpt.rag.assembler.db_schema import DBSchemaAssembler

            db_assembler = DBSchemaAssembler.load_from_connection(
                connector=db_summary_client.db,
                index_store=vector_connector.index_client,
            )
            # Nothing to persist when the schema produced no chunks.
            if db_assembler.get_chunks():
                db_assembler.persist()
        else:
            logger.info("Vector store name %s exist", vector_store_name)
        logger.info("initialize db summary profile success...")

    def delete_db_profile(self, dbname):
        """Delete db profile.

        Args:
            dbname: Name of the database whose ``<dbname>_profile`` vector store
                is deleted.
        """
        vector_store_name = self._profile_store_name(dbname)
        vector_connector = self._create_vector_connector(vector_store_name)
        vector_connector.delete_vector_name(vector_store_name)
        logger.info("delete db profile %s success", dbname)

    @staticmethod
    def create_summary_client(dbname: str, db_type: str):
        """Create a summary client based on the database type.

        Args:
            dbname (str): The name of the database.
            db_type (str): The type of the database.

        Returns:
            A ``GdbmsSummary`` when ``db_type`` contains ``"graph"``, otherwise a
            ``RdbmsSummary``.
        """
        if "graph" in db_type:
            return GdbmsSummary(dbname, db_type)
        return RdbmsSummary(dbname, db_type)
|