feat(RAG): add metadata properties filters (#1395)

This commit is contained in:
Aries-ckt
2024-04-10 14:33:24 +08:00
committed by GitHub
parent 0f2b46da62
commit 37e7c0151b
26 changed files with 619 additions and 166 deletions

View File

@@ -8,6 +8,7 @@ from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.chat_util import run_async_tasks
@@ -93,11 +94,14 @@ class DBSchemaRetriever(BaseRetriever):
self._need_embeddings = True
self._rerank = rerank or DefaultRanker(self._top_k)
def _retrieve(self, query: str) -> List[Chunk]:
def _retrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
filters: metadata filters.
Returns:
List[Chunk]: list of chunks
@@ -105,7 +109,7 @@ class DBSchemaRetriever(BaseRetriever):
if self._need_embeddings:
queries = [query]
candidates = [
self._vector_store_connector.similar_search(query, self._top_k)
self._vector_store_connector.similar_search(query, self._top_k, filters)
for query in queries
]
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
@@ -115,30 +119,39 @@ class DBSchemaRetriever(BaseRetriever):
table_summaries = _parse_db_summary(self._connector)
return [Chunk(content=table_summary) for table_summary in table_summaries]
def _retrieve_with_score(self, query: str, score_threshold: float) -> List[Chunk]:
def _retrieve_with_score(
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
filters: metadata filters.
Returns:
List[Chunk]: list of chunks
"""
return self._retrieve(query)
return self._retrieve(query, filters)
async def _aretrieve(self, query: str) -> List[Chunk]:
async def _aretrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
filters: metadata filters.
Returns:
List[Chunk]: list of chunks
"""
if self._need_embeddings:
queries = [query]
candidates = [self._similarity_search(query) for query in queries]
candidates = [self._similarity_search(query, filters) for query in queries]
result_candidates = await run_async_tasks(
tasks=candidates, concurrency_limit=1
)
@@ -154,22 +167,25 @@ class DBSchemaRetriever(BaseRetriever):
return [Chunk(content=table_summary) for table_summary in table_summaries]
async def _aretrieve_with_score(
self, query: str, score_threshold: float
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
filters: metadata filters.
"""
return await self._aretrieve(query)
return await self._aretrieve(query, filters)
async def _similarity_search(self, query) -> List[Chunk]:
async def _similarity_search(
self, query, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search."""
return self._vector_store_connector.similar_search(
query,
self._top_k,
)
return self._vector_store_connector.similar_search(query, self._top_k, filters)
async def _aparse_db_summary(self) -> List[str]:
"""Similar search."""