Merge remote-tracking branch 'origin/main' into feat_rag_graph

This commit is contained in:
aries_ckt 2023-10-16 12:48:08 +08:00
commit 71c31c3e2e

View File

@ -22,8 +22,10 @@ chat_factory = ChatFactory()
class DBSummaryClient:
"""db summary client, provide db_summary_embedding(put db profile and table profile summary into vector store)
"""DB Summary client, provide db_summary_embedding(put db profile and table profile summary into vector store)
, get_similar_tables method(get user query related tables info)
Args:
system_app (SystemApp): Main System Application class that manages the lifecycle and registration of components.
"""
def __init__(self, system_app: SystemApp):
@ -160,6 +162,12 @@ class DBSummaryClient:
)
def init_db_profile(self, db_summary_client, dbname, embeddings):
"""db profile initialization
Args:
db_summary_client(DBSummaryClient): DB Summary Client
dbname(str): dbname
embeddings(SourceEmbedding): embedding used to read string documents
"""
from pilot.embedding_engine.string_embedding import StringEmbedding
vector_store_name = dbname + "_profile"
@ -176,9 +184,15 @@ class DBSummaryClient:
docs = []
docs.extend(embedding.read_batch())
for table_summary in db_summary_client.table_info_json():
from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=len(table_summary), chunk_overlap=100
)
embedding = StringEmbedding(
table_summary,
profile_store_config,
file_path=table_summary,
vector_store_config=profile_store_config,
text_splitter=text_splitter,
)
docs.extend(embedding.read_batch())
embedding.index_to_store(docs)