DB-GPT/dbgpt/rag/summary/db_summary_client.py
Fangyin Cheng 65c875db20
feat(core): Support higher-order operators (#1984)
Co-authored-by: 谨欣 <echo.cmy@antgroup.com>
2024-09-09 10:15:37 +08:00

143 lines
5.1 KiB
Python

"""DBSummaryClient class."""
import logging
import traceback
from typing import List
from dbgpt._private.config import Config
from dbgpt.component import SystemApp
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.rag.summary.gdbms_db_summary import GdbmsSummary
from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Shared configuration object. DBSummaryClient reads EMBEDDING_MODEL,
# VECTOR_STORE_TYPE and local_db_manager from it.
CFG = Config()
class DBSummaryClient:
    """The client for DBSummary.

    DB Summary client. Provides ``db_summary_embedding`` (put db profile and
    table profile summary into the vector store) and ``get_db_summary`` (get the
    table info related to a user query).

    Args:
        system_app (SystemApp): Main System Application class that manages the
            lifecycle and registration of components.
    """

    def __init__(self, system_app: SystemApp):
        """Create a new DBSummaryClient.

        Looks up the ``embedding_factory`` component on ``system_app`` and uses
        it to create the embeddings for the model selected by
        ``CFG.EMBEDDING_MODEL``.
        """
        self.system_app = system_app
        from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory

        embedding_factory: EmbeddingFactory = self.system_app.get_component(
            "embedding_factory", component_type=EmbeddingFactory
        )
        self.embeddings = embedding_factory.create(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
        )

    def _create_vector_connector(self, vector_store_name: str):
        """Build the vector store connector for ``vector_store_name``.

        Shared by the query, initialization and deletion paths so that they
        all construct the connector in exactly the same way.

        Args:
            vector_store_name (str): Name of the vector store to connect to.

        Returns:
            The connector created by ``VectorStoreConnector.from_default`` for
            the store type configured in ``CFG.VECTOR_STORE_TYPE``.
        """
        # NOTE(review): these imports are kept at function scope as in the
        # rest of this class - presumably to avoid import cycles; confirm
        # before hoisting them to module level.
        from dbgpt.serve.rag.connector import VectorStoreConnector
        from dbgpt.storage.vector_store.base import VectorStoreConfig

        vector_store_config = VectorStoreConfig(name=vector_store_name)
        return VectorStoreConnector.from_default(
            CFG.VECTOR_STORE_TYPE,
            embedding_fn=self.embeddings,
            vector_store_config=vector_store_config,
        )

    def db_summary_embedding(self, dbname: str, db_type: str) -> None:
        """Put db profile and table profile summary into vector store.

        Args:
            dbname (str): The name of the database.
            db_type (str): The type of the database.
        """
        db_summary_client = self.create_summary_client(dbname, db_type)
        self.init_db_profile(db_summary_client, dbname)
        logger.info("db summary embedding success")

    def get_db_summary(self, dbname: str, query: str, topk: int) -> List[str]:
        """Get user query related tables info.

        Args:
            dbname (str): The name of the database; its profile is read from
                the ``<dbname>_profile`` vector store.
            query (str): The user query passed to the retriever.
            topk (int): Passed to the retriever as ``top_k``.

        Returns:
            List[str]: The ``content`` of every retrieved document.
        """
        from dbgpt.rag.retriever.db_schema import DBSchemaRetriever

        vector_connector = self._create_vector_connector(dbname + "_profile")
        retriever = DBSchemaRetriever(
            top_k=topk, index_store=vector_connector.index_client
        )
        table_docs = retriever.retrieve(query)
        return [d.content for d in table_docs]

    def init_db_summary(self) -> None:
        """Initialize db summary profile for every locally managed database.

        A failure for one database is logged as a warning and does not stop
        the remaining databases from being processed.
        """
        db_manager = CFG.local_db_manager
        for item in db_manager.get_db_list():
            try:
                self.db_summary_embedding(item["db_name"], item["db_type"])
            except Exception as e:
                # Deliberately broad: this is a best-effort loop, so one bad
                # database must not abort the others.
                message = traceback.format_exc()
                # ``Logger.warn`` is a deprecated alias; use ``warning``.
                logger.warning(
                    f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}, '
                    f"detail: {message}"
                )

    def init_db_profile(self, db_summary_client, dbname: str) -> None:
        """Initialize db summary profile.

        If the ``<dbname>_profile`` vector store does not exist yet, the
        database schema is loaded through ``db_summary_client.db`` and
        persisted when it yields at least one chunk. An existing store is
        left untouched.

        Args:
            db_summary_client: The summary object returned by
                ``create_summary_client``; only its ``db`` attribute is used.
            dbname (str): The name of the database.
        """
        vector_store_name = dbname + "_profile"
        vector_connector = self._create_vector_connector(vector_store_name)
        if not vector_connector.vector_name_exists():
            from dbgpt.rag.assembler.db_schema import DBSchemaAssembler

            db_assembler = DBSchemaAssembler.load_from_connection(
                connector=db_summary_client.db,
                index_store=vector_connector.index_client,
            )
            # Nothing to persist when no chunks were produced.
            if db_assembler.get_chunks():
                db_assembler.persist()
        else:
            logger.info(f"Vector store name {vector_store_name} exist")
        logger.info("initialize db summary profile success...")

    def delete_db_profile(self, dbname: str) -> None:
        """Delete db profile.

        Args:
            dbname (str): The name of the database whose ``<dbname>_profile``
                vector store is deleted.
        """
        vector_store_name = dbname + "_profile"
        vector_connector = self._create_vector_connector(vector_store_name)
        vector_connector.delete_vector_name(vector_store_name)
        logger.info(f"delete db profile {dbname} success")

    @staticmethod
    def create_summary_client(dbname: str, db_type: str):
        """Create a summary client based on the database type.

        Args:
            dbname (str): The name of the database.
            db_type (str): The type of the database.

        Returns:
            A ``GdbmsSummary`` when ``db_type`` contains ``"graph"``,
            otherwise a ``RdbmsSummary``.
        """
        if "graph" in db_type:
            return GdbmsSummary(dbname, db_type)
        return RdbmsSummary(dbname, db_type)