mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 01:54:44 +00:00
feat(RAG):add metadata properties filters (#1395)
This commit is contained in:
@@ -8,6 +8,7 @@ from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.chat_util import run_async_tasks
|
||||
|
||||
|
||||
@@ -93,11 +94,14 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
self._need_embeddings = True
|
||||
self._rerank = rerank or DefaultRanker(self._top_k)
|
||||
|
||||
def _retrieve(self, query: str) -> List[Chunk]:
|
||||
def _retrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
filters: metadata filters.
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
@@ -105,7 +109,7 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
if self._need_embeddings:
|
||||
queries = [query]
|
||||
candidates = [
|
||||
self._vector_store_connector.similar_search(query, self._top_k)
|
||||
self._vector_store_connector.similar_search(query, self._top_k, filters)
|
||||
for query in queries
|
||||
]
|
||||
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
||||
@@ -115,30 +119,39 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
table_summaries = _parse_db_summary(self._connector)
|
||||
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||
|
||||
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
|
||||
def _retrieve_with_score(
|
||||
self,
|
||||
query: str,
|
||||
score_threshold: float,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks with score.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
score_threshold (float): score threshold
|
||||
filters: metadata filters.
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
return self._retrieve(query)
|
||||
return self._retrieve(query, filters)
|
||||
|
||||
async def _aretrieve(self, query: str) -> List[Chunk]:
|
||||
async def _aretrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
filters: metadata filters.
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
if self._need_embeddings:
|
||||
queries = [query]
|
||||
candidates = [self._similarity_search(query) for query in queries]
|
||||
candidates = [self._similarity_search(query, filters) for query in queries]
|
||||
result_candidates = await run_async_tasks(
|
||||
tasks=candidates, concurrency_limit=1
|
||||
)
|
||||
@@ -154,22 +167,25 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||
|
||||
async def _aretrieve_with_score(
|
||||
self, query: str, score_threshold: float
|
||||
self,
|
||||
query: str,
|
||||
score_threshold: float,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks with score.
|
||||
|
||||
Args:
|
||||
query (str): query text
|
||||
score_threshold (float): score threshold
|
||||
filters: metadata filters.
|
||||
"""
|
||||
return await self._aretrieve(query)
|
||||
return await self._aretrieve(query, filters)
|
||||
|
||||
async def _similarity_search(self, query) -> List[Chunk]:
|
||||
async def _similarity_search(
|
||||
self, query, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Similar search."""
|
||||
return self._vector_store_connector.similar_search(
|
||||
query,
|
||||
self._top_k,
|
||||
)
|
||||
return self._vector_store_connector.similar_search(query, self._top_k, filters)
|
||||
|
||||
async def _aparse_db_summary(self) -> List[str]:
|
||||
"""Similar search."""
|
||||
|
Reference in New Issue
Block a user