Feat rdb summary wide table (#2035)

Co-authored-by: dongzhancai1 <dongzhancai1@jd.com>
Co-authored-by: dong <dongzhancai@iie2.com>
This commit is contained in:
Cooper
2024-12-18 20:34:21 +08:00
committed by GitHub
parent 7f4b5e79cf
commit 9b0161e521
17 changed files with 948 additions and 243 deletions

View File

@@ -1,18 +1,23 @@
"""DBSchema retriever."""
import logging
import os
from typing import List, Optional
from functools import reduce
from typing import List, Optional, cast
from dbgpt._private.config import Config
from dbgpt.core import Chunk
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.index.base import IndexStoreBase
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.rag.summary.gdbms_db_summary import _parse_db_summary
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
from dbgpt.util.chat_util import run_tasks
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
from dbgpt.util.tracer import root_tracer
logger = logging.getLogger(__name__)
CFG = Config()
class DBSchemaRetriever(BaseRetriever):
@@ -20,7 +25,9 @@ class DBSchemaRetriever(BaseRetriever):
def __init__(
self,
index_store: IndexStoreBase,
table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
separator: str = "--table-field-separator--",
top_k: int = 4,
connector: Optional[BaseConnector] = None,
query_rewrite: bool = False,
@@ -30,7 +37,11 @@ class DBSchemaRetriever(BaseRetriever):
"""Create DBSchemaRetriever.
Args:
index_store(IndexStore): index connector
table_vector_store_connector: VectorStoreConnector
to load and retrieve table info.
field_vector_store_connector: VectorStoreConnector
to load and retrieve field info.
separator: field/table separator
top_k (int): top k
connector (Optional[BaseConnector]): RDBMSConnector.
query_rewrite (bool): query rewrite
@@ -70,34 +81,42 @@ class DBSchemaRetriever(BaseRetriever):
connector = _create_temporary_connection()
vector_store_config = ChromaVectorConfig(name="vector_store_name")
embedding_model_path = "{your_embedding_model_path}"
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
config = ChromaVectorConfig(
persist_path=PILOT_PATH,
name="dbschema_rag_test",
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(
MODEL_PATH, "text2vec-large-chinese"
),
).create(),
vector_connector = VectorStoreConnector.from_default(
"Chroma",
vector_store_config=vector_store_config,
embedding_fn=embedding_fn,
)
vector_store = ChromaStore(config)
# get db struct retriever
retriever = DBSchemaRetriever(
top_k=3,
index_store=vector_store,
vector_store_connector=vector_connector,
connector=connector,
)
chunks = retriever.retrieve("show columns from table")
result = [chunk.content for chunk in chunks]
print(f"db struct rag example results:{result}")
"""
self._separator = separator
self._top_k = top_k
self._connector = connector
self._query_rewrite = query_rewrite
self._index_store = index_store
self._table_vector_store_connector = table_vector_store_connector
field_vector_store_config = VectorStoreConfig(
name=table_vector_store_connector.vector_store_config.name + "_field"
)
self._field_vector_store_connector = (
field_vector_store_connector
or VectorStoreConnector.from_default(
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
self._table_vector_store_connector.current_embeddings,
vector_store_config=field_vector_store_config,
)
)
self._need_embeddings = False
if self._index_store:
if self._table_vector_store_connector:
self._need_embeddings = True
self._rerank = rerank or DefaultRanker(self._top_k)
@@ -114,15 +133,8 @@ class DBSchemaRetriever(BaseRetriever):
List[Chunk]: list of chunks
"""
if self._need_embeddings:
queries = [query]
candidates = [
self._index_store.similar_search(query, self._top_k, filters)
for query in queries
]
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
return self._similarity_search(query, filters)
else:
if not self._connector:
raise RuntimeError("RDBMSConnector connection is required.")
table_summaries = _parse_db_summary(self._connector)
return [Chunk(content=table_summary) for table_summary in table_summaries]
@@ -156,30 +168,11 @@ class DBSchemaRetriever(BaseRetriever):
Returns:
List[Chunk]: list of chunks
"""
if self._need_embeddings:
queries = [query]
candidates = [
self._similarity_search(
query, filters, root_tracer.get_current_span_id()
)
for query in queries
]
result_candidates = await run_async_tasks(
tasks=candidates, concurrency_limit=1
)
return cast(List[Chunk], reduce(lambda x, y: x + y, result_candidates))
else:
from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401
_parse_db_summary,
)
table_summaries = await run_async_tasks(
tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())],
concurrency_limit=1,
)
return [
Chunk(content=table_summary) for table_summary in table_summaries[0]
]
return await blocking_func_to_async_no_executor(
func=self._retrieve,
query=query,
filters=filters,
)
async def _aretrieve_with_score(
self,
@@ -196,34 +189,40 @@ class DBSchemaRetriever(BaseRetriever):
"""
return await self._aretrieve(query, filters)
async def _similarity_search(
self,
query,
filters: Optional[MetadataFilters] = None,
parent_span_id: Optional[str] = None,
def _retrieve_field(self, table_chunk: Chunk, query) -> Chunk:
metadata = table_chunk.metadata
metadata["part"] = "field"
filters = [MetadataFilter(key=k, value=v) for k, v in metadata.items()]
field_chunks = self._field_vector_store_connector.similar_search_with_scores(
query, self._top_k, 0, MetadataFilters(filters=filters)
)
field_contents = [chunk.content for chunk in field_chunks]
table_chunk.content += "\n" + self._separator + "\n" + "\n".join(field_contents)
return table_chunk
def _similarity_search(
self, query, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search."""
with root_tracer.start_span(
"dbgpt.rag.retriever.db_schema._similarity_search",
parent_span_id,
metadata={"query": query},
):
return await blocking_func_to_async_no_executor(
self._index_store.similar_search, query, self._top_k, filters
)
table_chunks = self._table_vector_store_connector.similar_search_with_scores(
query, self._top_k, 0, filters
)
async def _aparse_db_summary(
self, parent_span_id: Optional[str] = None
) -> List[str]:
"""Similar search."""
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
not_sep_chunks = [
chunk for chunk in table_chunks if not chunk.metadata.get("separated")
]
separated_chunks = [
chunk for chunk in table_chunks if chunk.metadata.get("separated")
]
if not separated_chunks:
return not_sep_chunks
if not self._connector:
raise RuntimeError("RDBMSConnector connection is required.")
with root_tracer.start_span(
"dbgpt.rag.retriever.db_schema._aparse_db_summary",
parent_span_id,
):
return await blocking_func_to_async_no_executor(
_parse_db_summary, self._connector
)
# Create tasks list
tasks = [
lambda c=chunk: self._retrieve_field(c, query) for chunk in separated_chunks
]
# Run tasks concurrently
separated_result = run_tasks(tasks, concurrency_limit=3)
# Combine and return results
return not_sep_chunks + separated_result