mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-09 04:08:10 +00:00
Co-authored-by: dongzhancai1 <dongzhancai1@jd.com> Co-authored-by: dong <dongzhancai@iie2.com>
229 lines
8.2 KiB
Python
229 lines
8.2 KiB
Python
"""DBSchema retriever."""
|
|
import logging
|
|
import os
|
|
from typing import List, Optional
|
|
|
|
from dbgpt._private.config import Config
|
|
from dbgpt.core import Chunk
|
|
from dbgpt.datasource.base import BaseConnector
|
|
from dbgpt.rag.retriever.base import BaseRetriever
|
|
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
|
from dbgpt.rag.summary.gdbms_db_summary import _parse_db_summary
|
|
from dbgpt.serve.rag.connector import VectorStoreConnector
|
|
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
|
from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
|
|
from dbgpt.util.chat_util import run_tasks
|
|
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
CFG = Config()
|
|
|
|
|
|
class DBSchemaRetriever(BaseRetriever):
|
|
"""DBSchema retriever."""
|
|
|
|
def __init__(
|
|
self,
|
|
table_vector_store_connector: VectorStoreConnector,
|
|
field_vector_store_connector: VectorStoreConnector = None,
|
|
separator: str = "--table-field-separator--",
|
|
top_k: int = 4,
|
|
connector: Optional[BaseConnector] = None,
|
|
query_rewrite: bool = False,
|
|
rerank: Optional[Ranker] = None,
|
|
**kwargs
|
|
):
|
|
"""Create DBSchemaRetriever.
|
|
|
|
Args:
|
|
table_vector_store_connector: VectorStoreConnector
|
|
to load and retrieve table info.
|
|
field_vector_store_connector: VectorStoreConnector
|
|
to load and retrieve field info.
|
|
separator: field/table separator
|
|
top_k (int): top k
|
|
connector (Optional[BaseConnector]): RDBMSConnector.
|
|
query_rewrite (bool): query rewrite
|
|
rerank (Ranker): rerank
|
|
|
|
Examples:
|
|
.. code-block:: python
|
|
|
|
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
|
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
|
|
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
|
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
|
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
|
|
|
|
|
|
def _create_temporary_connection():
|
|
connect = SQLiteTempConnector.create_temporary_db()
|
|
connect.create_temp_tables(
|
|
{
|
|
"user": {
|
|
"columns": {
|
|
"id": "INTEGER PRIMARY KEY",
|
|
"name": "TEXT",
|
|
"age": "INTEGER",
|
|
},
|
|
"data": [
|
|
(1, "Tom", 10),
|
|
(2, "Jerry", 16),
|
|
(3, "Jack", 18),
|
|
(4, "Alice", 20),
|
|
(5, "Bob", 22),
|
|
],
|
|
}
|
|
}
|
|
)
|
|
return connect
|
|
|
|
|
|
connector = _create_temporary_connection()
|
|
vector_store_config = ChromaVectorConfig(name="vector_store_name")
|
|
embedding_model_path = "{your_embedding_model_path}"
|
|
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
|
|
vector_connector = VectorStoreConnector.from_default(
|
|
"Chroma",
|
|
vector_store_config=vector_store_config,
|
|
embedding_fn=embedding_fn,
|
|
)
|
|
# get db struct retriever
|
|
retriever = DBSchemaRetriever(
|
|
top_k=3,
|
|
vector_store_connector=vector_connector,
|
|
connector=connector,
|
|
)
|
|
chunks = retriever.retrieve("show columns from table")
|
|
result = [chunk.content for chunk in chunks]
|
|
print(f"db struct rag example results:{result}")
|
|
"""
|
|
self._separator = separator
|
|
self._top_k = top_k
|
|
self._connector = connector
|
|
self._query_rewrite = query_rewrite
|
|
self._table_vector_store_connector = table_vector_store_connector
|
|
field_vector_store_config = VectorStoreConfig(
|
|
name=table_vector_store_connector.vector_store_config.name + "_field"
|
|
)
|
|
self._field_vector_store_connector = (
|
|
field_vector_store_connector
|
|
or VectorStoreConnector.from_default(
|
|
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
|
|
self._table_vector_store_connector.current_embeddings,
|
|
vector_store_config=field_vector_store_config,
|
|
)
|
|
)
|
|
self._need_embeddings = False
|
|
if self._table_vector_store_connector:
|
|
self._need_embeddings = True
|
|
self._rerank = rerank or DefaultRanker(self._top_k)
|
|
|
|
def _retrieve(
|
|
self, query: str, filters: Optional[MetadataFilters] = None
|
|
) -> List[Chunk]:
|
|
"""Retrieve knowledge chunks.
|
|
|
|
Args:
|
|
query (str): query text
|
|
filters: metadata filters.
|
|
|
|
Returns:
|
|
List[Chunk]: list of chunks
|
|
"""
|
|
if self._need_embeddings:
|
|
return self._similarity_search(query, filters)
|
|
else:
|
|
table_summaries = _parse_db_summary(self._connector)
|
|
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
|
|
|
def _retrieve_with_score(
|
|
self,
|
|
query: str,
|
|
score_threshold: float,
|
|
filters: Optional[MetadataFilters] = None,
|
|
) -> List[Chunk]:
|
|
"""Retrieve knowledge chunks with score.
|
|
|
|
Args:
|
|
query (str): query text
|
|
score_threshold (float): score threshold
|
|
filters: metadata filters.
|
|
|
|
Returns:
|
|
List[Chunk]: list of chunks
|
|
"""
|
|
return self._retrieve(query, filters)
|
|
|
|
async def _aretrieve(
|
|
self, query: str, filters: Optional[MetadataFilters] = None
|
|
) -> List[Chunk]:
|
|
"""Retrieve knowledge chunks.
|
|
|
|
Args:
|
|
query (str): query text
|
|
filters: metadata filters.
|
|
|
|
Returns:
|
|
List[Chunk]: list of chunks
|
|
"""
|
|
return await blocking_func_to_async_no_executor(
|
|
func=self._retrieve,
|
|
query=query,
|
|
filters=filters,
|
|
)
|
|
|
|
async def _aretrieve_with_score(
|
|
self,
|
|
query: str,
|
|
score_threshold: float,
|
|
filters: Optional[MetadataFilters] = None,
|
|
) -> List[Chunk]:
|
|
"""Retrieve knowledge chunks with score.
|
|
|
|
Args:
|
|
query (str): query text
|
|
score_threshold (float): score threshold
|
|
filters: metadata filters.
|
|
"""
|
|
return await self._aretrieve(query, filters)
|
|
|
|
def _retrieve_field(self, table_chunk: Chunk, query) -> Chunk:
|
|
metadata = table_chunk.metadata
|
|
metadata["part"] = "field"
|
|
filters = [MetadataFilter(key=k, value=v) for k, v in metadata.items()]
|
|
field_chunks = self._field_vector_store_connector.similar_search_with_scores(
|
|
query, self._top_k, 0, MetadataFilters(filters=filters)
|
|
)
|
|
field_contents = [chunk.content for chunk in field_chunks]
|
|
table_chunk.content += "\n" + self._separator + "\n" + "\n".join(field_contents)
|
|
return table_chunk
|
|
|
|
def _similarity_search(
|
|
self, query, filters: Optional[MetadataFilters] = None
|
|
) -> List[Chunk]:
|
|
"""Similar search."""
|
|
table_chunks = self._table_vector_store_connector.similar_search_with_scores(
|
|
query, self._top_k, 0, filters
|
|
)
|
|
|
|
not_sep_chunks = [
|
|
chunk for chunk in table_chunks if not chunk.metadata.get("separated")
|
|
]
|
|
separated_chunks = [
|
|
chunk for chunk in table_chunks if chunk.metadata.get("separated")
|
|
]
|
|
if not separated_chunks:
|
|
return not_sep_chunks
|
|
|
|
# Create tasks list
|
|
tasks = [
|
|
lambda c=chunk: self._retrieve_field(c, query) for chunk in separated_chunks
|
|
]
|
|
# Run tasks concurrently
|
|
separated_result = run_tasks(tasks, concurrency_limit=3)
|
|
|
|
# Combine and return results
|
|
return not_sep_chunks + separated_result
|