DB-GPT/dbgpt/rag/retriever/db_schema.py
Cooper 9b0161e521
Feat rdb summary wide table (#2035)
Co-authored-by: dongzhancai1 <dongzhancai1@jd.com>
Co-authored-by: dong <dongzhancai@iie2.com>
2024-12-18 20:34:21 +08:00

229 lines
8.2 KiB
Python

"""DBSchema retriever."""
import logging
import os
from typing import List, Optional
from dbgpt._private.config import Config
from dbgpt.core import Chunk
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.summary.gdbms_db_summary import _parse_db_summary
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
from dbgpt.util.chat_util import run_tasks
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
logger = logging.getLogger(__name__)
CFG = Config()
class DBSchemaRetriever(BaseRetriever):
"""DBSchema retriever."""
def __init__(
self,
table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
separator: str = "--table-field-separator--",
top_k: int = 4,
connector: Optional[BaseConnector] = None,
query_rewrite: bool = False,
rerank: Optional[Ranker] = None,
**kwargs
):
"""Create DBSchemaRetriever.
Args:
table_vector_store_connector: VectorStoreConnector
to load and retrieve table info.
field_vector_store_connector: VectorStoreConnector
to load and retrieve field info.
separator: field/table separator
top_k (int): top k
connector (Optional[BaseConnector]): RDBMSConnector.
query_rewrite (bool): query rewrite
rerank (Ranker): rerank
Examples:
.. code-block:: python
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
def _create_temporary_connection():
connect = SQLiteTempConnector.create_temporary_db()
connect.create_temp_tables(
{
"user": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 10),
(2, "Jerry", 16),
(3, "Jack", 18),
(4, "Alice", 20),
(5, "Bob", 22),
],
}
}
)
return connect
connector = _create_temporary_connection()
vector_store_config = ChromaVectorConfig(name="vector_store_name")
embedding_model_path = "{your_embedding_model_path}"
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
vector_connector = VectorStoreConnector.from_default(
"Chroma",
vector_store_config=vector_store_config,
embedding_fn=embedding_fn,
)
# get db struct retriever
retriever = DBSchemaRetriever(
top_k=3,
vector_store_connector=vector_connector,
connector=connector,
)
chunks = retriever.retrieve("show columns from table")
result = [chunk.content for chunk in chunks]
print(f"db struct rag example results:{result}")
"""
self._separator = separator
self._top_k = top_k
self._connector = connector
self._query_rewrite = query_rewrite
self._table_vector_store_connector = table_vector_store_connector
field_vector_store_config = VectorStoreConfig(
name=table_vector_store_connector.vector_store_config.name + "_field"
)
self._field_vector_store_connector = (
field_vector_store_connector
or VectorStoreConnector.from_default(
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
self._table_vector_store_connector.current_embeddings,
vector_store_config=field_vector_store_config,
)
)
self._need_embeddings = False
if self._table_vector_store_connector:
self._need_embeddings = True
self._rerank = rerank or DefaultRanker(self._top_k)
def _retrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
filters: metadata filters.
Returns:
List[Chunk]: list of chunks
"""
if self._need_embeddings:
return self._similarity_search(query, filters)
else:
table_summaries = _parse_db_summary(self._connector)
return [Chunk(content=table_summary) for table_summary in table_summaries]
def _retrieve_with_score(
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
filters: metadata filters.
Returns:
List[Chunk]: list of chunks
"""
return self._retrieve(query, filters)
async def _aretrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
filters: metadata filters.
Returns:
List[Chunk]: list of chunks
"""
return await blocking_func_to_async_no_executor(
func=self._retrieve,
query=query,
filters=filters,
)
async def _aretrieve_with_score(
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
filters: metadata filters.
"""
return await self._aretrieve(query, filters)
def _retrieve_field(self, table_chunk: Chunk, query) -> Chunk:
metadata = table_chunk.metadata
metadata["part"] = "field"
filters = [MetadataFilter(key=k, value=v) for k, v in metadata.items()]
field_chunks = self._field_vector_store_connector.similar_search_with_scores(
query, self._top_k, 0, MetadataFilters(filters=filters)
)
field_contents = [chunk.content for chunk in field_chunks]
table_chunk.content += "\n" + self._separator + "\n" + "\n".join(field_contents)
return table_chunk
def _similarity_search(
self, query, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search."""
table_chunks = self._table_vector_store_connector.similar_search_with_scores(
query, self._top_k, 0, filters
)
not_sep_chunks = [
chunk for chunk in table_chunks if not chunk.metadata.get("separated")
]
separated_chunks = [
chunk for chunk in table_chunks if chunk.metadata.get("separated")
]
if not separated_chunks:
return not_sep_chunks
# Create tasks list
tasks = [
lambda c=chunk: self._retrieve_field(c, query) for chunk in separated_chunks
]
# Run tasks concurrently
separated_result = run_tasks(tasks, concurrency_limit=3)
# Combine and return results
return not_sep_chunks + separated_result