mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-06 19:40:13 +00:00
refactor: RAG Refactor (#985)
Co-authored-by: Aralhi <xiaoping0501@gmail.com> Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
@@ -23,41 +23,43 @@ class DBSummaryClient:
|
||||
|
||||
def __init__(self, system_app: SystemApp):
|
||||
self.system_app = system_app
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
|
||||
def db_summary_embedding(self, dbname, db_type):
|
||||
"""put db profile and table profile summary into vector store"""
|
||||
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
|
||||
db_summary_client = RdbmsSummary(dbname, db_type)
|
||||
embedding_factory = self.system_app.get_component(
|
||||
"embedding_factory", EmbeddingFactory
|
||||
)
|
||||
embeddings = embedding_factory.create(
|
||||
self.embeddings = embedding_factory.create(
|
||||
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
|
||||
)
|
||||
self.init_db_profile(db_summary_client, dbname, embeddings)
|
||||
|
||||
def db_summary_embedding(self, dbname, db_type):
|
||||
"""put db profile and table profile summary into vector store"""
|
||||
|
||||
db_summary_client = RdbmsSummary(dbname, db_type)
|
||||
|
||||
self.init_db_profile(db_summary_client, dbname)
|
||||
|
||||
logger.info("db summary embedding success")
|
||||
|
||||
def get_db_summary(self, dbname, query, topk):
|
||||
"""get user query related tables info"""
|
||||
from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
|
||||
vector_store_config = {
|
||||
"vector_store_name": dbname + "_profile",
|
||||
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
||||
}
|
||||
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||
"embedding_factory", EmbeddingFactory
|
||||
)
|
||||
knowledge_embedding_client = EmbeddingEngine(
|
||||
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
|
||||
vector_store_config = VectorStoreConfig(name=dbname + "_profile")
|
||||
vector_connector = VectorStoreConnector.from_default(
|
||||
CFG.VECTOR_STORE_TYPE,
|
||||
embedding_fn=self.embeddings,
|
||||
vector_store_config=vector_store_config,
|
||||
embedding_factory=embedding_factory,
|
||||
)
|
||||
table_docs = knowledge_embedding_client.similar_search(query, topk)
|
||||
ans = [d.page_content for d in table_docs]
|
||||
from dbgpt.rag.retriever.db_struct import DBStructRetriever
|
||||
|
||||
retriever = DBStructRetriever(
|
||||
top_k=topk, vector_store_connector=vector_connector
|
||||
)
|
||||
table_docs = retriever.retrieve(query)
|
||||
ans = [d.content for d in table_docs]
|
||||
return ans
|
||||
|
||||
def init_db_summary(self):
|
||||
@@ -73,41 +75,30 @@ class DBSummaryClient:
|
||||
f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}, detail: {message}'
|
||||
)
|
||||
|
||||
def init_db_profile(self, db_summary_client, dbname, embeddings):
|
||||
def init_db_profile(self, db_summary_client, dbname):
|
||||
"""db profile initialization
|
||||
Args:
|
||||
db_summary_client(DBSummaryClient): DB Summary Client
|
||||
dbname(str): dbname
|
||||
embeddings(SourceEmbedding): embedding for read string document
|
||||
"""
|
||||
from dbgpt.rag.embedding_engine.string_embedding import StringEmbedding
|
||||
|
||||
vector_store_name = dbname + "_profile"
|
||||
profile_store_config = {
|
||||
"vector_store_name": vector_store_name,
|
||||
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
||||
"embeddings": embeddings,
|
||||
}
|
||||
embedding = StringEmbedding(
|
||||
file_path=None,
|
||||
vector_store_config=profile_store_config,
|
||||
)
|
||||
if not embedding.vector_name_exist():
|
||||
docs = []
|
||||
for table_summary in db_summary_client.table_summaries():
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=len(table_summary), chunk_overlap=0
|
||||
)
|
||||
embedding = StringEmbedding(
|
||||
file_path=table_summary,
|
||||
vector_store_config=profile_store_config,
|
||||
text_splitter=text_splitter,
|
||||
)
|
||||
docs.extend(embedding.read_batch())
|
||||
if len(docs) > 0:
|
||||
embedding.index_to_store(docs)
|
||||
vector_store_config = VectorStoreConfig(name=vector_store_name)
|
||||
vector_connector = VectorStoreConnector.from_default(
|
||||
CFG.VECTOR_STORE_TYPE,
|
||||
self.embeddings,
|
||||
vector_store_config=vector_store_config,
|
||||
)
|
||||
if not vector_connector.vector_name_exists():
|
||||
from dbgpt.serve.rag.assembler.db_struct import DBStructAssembler
|
||||
|
||||
db_assembler = DBStructAssembler.load_from_connection(
|
||||
connection=db_summary_client.db, vector_store_connector=vector_connector
|
||||
)
|
||||
if len(db_assembler.get_chunks()) > 0:
|
||||
db_assembler.persist()
|
||||
else:
|
||||
logger.info(f"Vector store name {vector_store_name} exist")
|
||||
logger.info("initialize db summary profile success...")
|
||||
|
0
dbgpt/rag/summary/tests/__init__.py
Normal file
0
dbgpt/rag/summary/tests/__init__.py
Normal file
68
dbgpt/rag/summary/tests/test_rdbms_summary.py
Normal file
68
dbgpt/rag/summary/tests/test_rdbms_summary.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import unittest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary
|
||||
|
||||
|
||||
class MockRDBMSDatabase(object):
|
||||
def get_users(self):
|
||||
return "user1, user2"
|
||||
|
||||
def get_grants(self):
|
||||
return "grant1, grant2"
|
||||
|
||||
def get_charset(self):
|
||||
return "utf8"
|
||||
|
||||
def get_collation(self):
|
||||
return "utf8_general_ci"
|
||||
|
||||
def get_table_names(self):
|
||||
return ["table1", "table2"]
|
||||
|
||||
def get_columns(self, table_name):
|
||||
if table_name == "table1":
|
||||
return [{"name": "column1", "comment": "first column"}, {"name": "column2"}]
|
||||
return [{"name": "column1"}]
|
||||
|
||||
def get_indexes(self, table_name):
|
||||
return [{"name": "index1", "column_names": ["column1"]}]
|
||||
|
||||
def get_table_comment(self, table_name):
|
||||
return {"text": f"{table_name} comment"}
|
||||
|
||||
|
||||
class TestRdbmsSummary(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.mock_local_db_manage = Mock()
|
||||
self.mock_local_db_manage.get_connect.return_value = MockRDBMSDatabase()
|
||||
self.patcher = patch(
|
||||
"dbgpt.rag.summary.rdbms_db_summary.CFG.LOCAL_DB_MANAGE",
|
||||
new=self.mock_local_db_manage,
|
||||
)
|
||||
self.patcher.start()
|
||||
|
||||
def test_rdbms_summary_initialization(self):
|
||||
rdbms_summary = RdbmsSummary(name="test_db", type="test_type")
|
||||
self.assertEqual(rdbms_summary.name, "test_db")
|
||||
self.assertEqual(rdbms_summary.type, "test_type")
|
||||
self.assertTrue("user info :user1, user2" in rdbms_summary.metadata)
|
||||
self.assertTrue("grant info:grant1, grant2" in rdbms_summary.metadata)
|
||||
self.assertTrue("charset:utf8" in rdbms_summary.metadata)
|
||||
self.assertTrue("collation:utf8_general_ci" in rdbms_summary.metadata)
|
||||
|
||||
def test_table_summaries(self):
|
||||
rdbms_summary = RdbmsSummary(name="test_db", type="test_type")
|
||||
summaries = rdbms_summary.table_summaries()
|
||||
self.assertTrue(
|
||||
"table1(column1 (first column), column2), and index keys: index1(`column1`) , and table comment: table1 comment"
|
||||
in summaries
|
||||
)
|
||||
self.assertTrue(
|
||||
"table2(column1), and index keys: index1(`column1`) , and table comment: table2 comment"
|
||||
in summaries
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Reference in New Issue
Block a user