mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 18:17:45 +00:00
Feat rdb summary wide table (#2035)
Co-authored-by: dongzhancai1 <dongzhancai1@jd.com> Co-authored-by: dong <dongzhancai@iie2.com>
This commit is contained in:
@@ -1,18 +1,23 @@
|
||||
"""DBSchema retriever."""
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from functools import reduce
|
||||
from typing import List, Optional, cast
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.rag.index.base import IndexStoreBase
|
||||
from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.chat_util import run_async_tasks
|
||||
from dbgpt.rag.summary.gdbms_db_summary import _parse_db_summary
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
|
||||
from dbgpt.util.chat_util import run_tasks
|
||||
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
|
||||
from dbgpt.util.tracer import root_tracer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class DBSchemaRetriever(BaseRetriever):
|
||||
@@ -20,7 +25,9 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_store: IndexStoreBase,
|
||||
table_vector_store_connector: VectorStoreConnector,
|
||||
field_vector_store_connector: VectorStoreConnector = None,
|
||||
separator: str = "--table-field-separator--",
|
||||
top_k: int = 4,
|
||||
connector: Optional[BaseConnector] = None,
|
||||
query_rewrite: bool = False,
|
||||
@@ -30,7 +37,11 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
"""Create DBSchemaRetriever.
|
||||
|
||||
Args:
|
||||
index_store(IndexStore): index connector
|
||||
table_vector_store_connector: VectorStoreConnector
|
||||
to load and retrieve table info.
|
||||
field_vector_store_connector: VectorStoreConnector
|
||||
to load and retrieve field info.
|
||||
separator: field/table separator
|
||||
top_k (int): top k
|
||||
connector (Optional[BaseConnector]): RDBMSConnector.
|
||||
query_rewrite (bool): query rewrite
|
||||
@@ -70,34 +81,42 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
|
||||
|
||||
connector = _create_temporary_connection()
|
||||
vector_store_config = ChromaVectorConfig(name="vector_store_name")
|
||||
embedding_model_path = "{your_embedding_model_path}"
|
||||
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
|
||||
config = ChromaVectorConfig(
|
||||
persist_path=PILOT_PATH,
|
||||
name="dbschema_rag_test",
|
||||
embedding_fn=DefaultEmbeddingFactory(
|
||||
default_model_name=os.path.join(
|
||||
MODEL_PATH, "text2vec-large-chinese"
|
||||
),
|
||||
).create(),
|
||||
vector_connector = VectorStoreConnector.from_default(
|
||||
"Chroma",
|
||||
vector_store_config=vector_store_config,
|
||||
embedding_fn=embedding_fn,
|
||||
)
|
||||
|
||||
vector_store = ChromaStore(config)
|
||||
# get db struct retriever
|
||||
retriever = DBSchemaRetriever(
|
||||
top_k=3,
|
||||
index_store=vector_store,
|
||||
vector_store_connector=vector_connector,
|
||||
connector=connector,
|
||||
)
|
||||
chunks = retriever.retrieve("show columns from table")
|
||||
result = [chunk.content for chunk in chunks]
|
||||
print(f"db struct rag example results:{result}")
|
||||
"""
|
||||
self._separator = separator
|
||||
self._top_k = top_k
|
||||
self._connector = connector
|
||||
self._query_rewrite = query_rewrite
|
||||
self._index_store = index_store
|
||||
self._table_vector_store_connector = table_vector_store_connector
|
||||
field_vector_store_config = VectorStoreConfig(
|
||||
name=table_vector_store_connector.vector_store_config.name + "_field"
|
||||
)
|
||||
self._field_vector_store_connector = (
|
||||
field_vector_store_connector
|
||||
or VectorStoreConnector.from_default(
|
||||
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
|
||||
self._table_vector_store_connector.current_embeddings,
|
||||
vector_store_config=field_vector_store_config,
|
||||
)
|
||||
)
|
||||
self._need_embeddings = False
|
||||
if self._index_store:
|
||||
if self._table_vector_store_connector:
|
||||
self._need_embeddings = True
|
||||
self._rerank = rerank or DefaultRanker(self._top_k)
|
||||
|
||||
@@ -114,15 +133,8 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
if self._need_embeddings:
|
||||
queries = [query]
|
||||
candidates = [
|
||||
self._index_store.similar_search(query, self._top_k, filters)
|
||||
for query in queries
|
||||
]
|
||||
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
||||
return self._similarity_search(query, filters)
|
||||
else:
|
||||
if not self._connector:
|
||||
raise RuntimeError("RDBMSConnector connection is required.")
|
||||
table_summaries = _parse_db_summary(self._connector)
|
||||
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||
|
||||
@@ -156,30 +168,11 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
if self._need_embeddings:
|
||||
queries = [query]
|
||||
candidates = [
|
||||
self._similarity_search(
|
||||
query, filters, root_tracer.get_current_span_id()
|
||||
)
|
||||
for query in queries
|
||||
]
|
||||
result_candidates = await run_async_tasks(
|
||||
tasks=candidates, concurrency_limit=1
|
||||
)
|
||||
return cast(List[Chunk], reduce(lambda x, y: x + y, result_candidates))
|
||||
else:
|
||||
from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401
|
||||
_parse_db_summary,
|
||||
)
|
||||
|
||||
table_summaries = await run_async_tasks(
|
||||
tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())],
|
||||
concurrency_limit=1,
|
||||
)
|
||||
return [
|
||||
Chunk(content=table_summary) for table_summary in table_summaries[0]
|
||||
]
|
||||
return await blocking_func_to_async_no_executor(
|
||||
func=self._retrieve,
|
||||
query=query,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
async def _aretrieve_with_score(
|
||||
self,
|
||||
@@ -196,34 +189,40 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
"""
|
||||
return await self._aretrieve(query, filters)
|
||||
|
||||
async def _similarity_search(
|
||||
self,
|
||||
query,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
parent_span_id: Optional[str] = None,
|
||||
def _retrieve_field(self, table_chunk: Chunk, query: str) -> Chunk:
    """Append the matching field entries to a table chunk.

    Queries the field vector store connector with a filter built from the
    table chunk's metadata (with ``part`` set to ``"field"``), then extends
    ``table_chunk.content`` with a newline, the configured separator, a
    newline, and the contents of the returned field chunks joined by
    newlines.

    Args:
        table_chunk (Chunk): table chunk to enrich. Its ``content`` is
            extended in place; its ``metadata`` is left untouched.
        query: query passed through to the field vector store search.

    Returns:
        Chunk: the same ``table_chunk`` object with the field contents
            appended.
    """
    # Build the filter from a copy: writing "part" straight into
    # table_chunk.metadata would relabel the caller's table chunk as a
    # field chunk as a side effect of retrieval.
    field_metadata = {**(table_chunk.metadata or {}), "part": "field"}
    filters = [MetadataFilter(key=k, value=v) for k, v in field_metadata.items()]
    # NOTE(review): the positional 0 is presumably the score threshold
    # (i.e. no score cut-off) -- confirm against the connector signature.
    field_chunks = self._field_vector_store_connector.similar_search_with_scores(
        query, self._top_k, 0, MetadataFilters(filters=filters)
    )
    field_contents = [chunk.content for chunk in field_chunks]
    # The separator is appended even when no field chunk matched, so
    # downstream consumers can always split on it for separated tables.
    table_chunk.content += "\n" + self._separator + "\n" + "\n".join(field_contents)
    return table_chunk
|
||||
|
||||
def _similarity_search(
|
||||
self, query, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Similar search."""
|
||||
with root_tracer.start_span(
|
||||
"dbgpt.rag.retriever.db_schema._similarity_search",
|
||||
parent_span_id,
|
||||
metadata={"query": query},
|
||||
):
|
||||
return await blocking_func_to_async_no_executor(
|
||||
self._index_store.similar_search, query, self._top_k, filters
|
||||
)
|
||||
table_chunks = self._table_vector_store_connector.similar_search_with_scores(
|
||||
query, self._top_k, 0, filters
|
||||
)
|
||||
|
||||
async def _aparse_db_summary(
|
||||
self, parent_span_id: Optional[str] = None
|
||||
) -> List[str]:
|
||||
"""Similar search."""
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
not_sep_chunks = [
|
||||
chunk for chunk in table_chunks if not chunk.metadata.get("separated")
|
||||
]
|
||||
separated_chunks = [
|
||||
chunk for chunk in table_chunks if chunk.metadata.get("separated")
|
||||
]
|
||||
if not separated_chunks:
|
||||
return not_sep_chunks
|
||||
|
||||
if not self._connector:
|
||||
raise RuntimeError("RDBMSConnector connection is required.")
|
||||
with root_tracer.start_span(
|
||||
"dbgpt.rag.retriever.db_schema._aparse_db_summary",
|
||||
parent_span_id,
|
||||
):
|
||||
return await blocking_func_to_async_no_executor(
|
||||
_parse_db_summary, self._connector
|
||||
)
|
||||
# Create tasks list
|
||||
tasks = [
|
||||
lambda c=chunk: self._retrieve_field(c, query) for chunk in separated_chunks
|
||||
]
|
||||
# Run tasks concurrently
|
||||
separated_result = run_tasks(tasks, concurrency_limit=3)
|
||||
|
||||
# Combine and return results
|
||||
return not_sep_chunks + separated_result
|
||||
|
Reference in New Issue
Block a user