mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-27 13:57:46 +00:00
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com> Co-authored-by: csunny <cfqsunny@163.com>
278 lines
10 KiB
Python
278 lines
10 KiB
Python
"""Embedding retriever."""
|
|
|
|
from functools import reduce
|
|
from typing import Any, Dict, List, Optional, cast
|
|
|
|
from dbgpt.core import Chunk
|
|
from dbgpt.rag.index.base import IndexStoreBase
|
|
from dbgpt.rag.retriever.base import BaseRetriever, RetrieverStrategy
|
|
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
|
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
|
from dbgpt.storage.vector_store.filters import MetadataFilters
|
|
from dbgpt.util.chat_util import run_async_tasks
|
|
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
|
|
from dbgpt.util.tracer import root_tracer
|
|
|
|
|
|
class EmbeddingRetriever(BaseRetriever):
    """Retrieve chunks from an index store by similarity search.

    The asynchronous entry points can optionally rewrite the query (when a
    ``QueryRewrite`` is configured) and search with both the original and the
    rewritten queries. Scored retrieval additionally passes the merged
    candidates through a ranker.
    """

    def __init__(
        self,
        index_store: IndexStoreBase,
        top_k: int = 4,
        query_rewrite: Optional[QueryRewrite] = None,
        rerank: Optional[Ranker] = None,
        retrieve_strategy: Optional[RetrieverStrategy] = RetrieverStrategy.EMBEDDING,
    ):
        """Create EmbeddingRetriever.

        Args:
            index_store (IndexStoreBase): index store the similarity searches
                are executed against.
            top_k (int): number of results requested from the index store for
                every query.
            query_rewrite (Optional[QueryRewrite]): query rewriter; when set,
                the asynchronous retrieval methods also search with the
                rewritten queries.
            rerank (Optional[Ranker]): ranker applied to scored candidates.
                Falls back to ``DefaultRanker(top_k)`` when not provided.
            retrieve_strategy (Optional[RetrieverStrategy]): retrieve strategy,
                stored on the instance; defaults to
                ``RetrieverStrategy.EMBEDDING``.

        Examples:
            .. code-block:: python

                from dbgpt.rag.retriever.embedding import EmbeddingRetriever

                # ``index_store`` is any ready-to-use ``IndexStoreBase``
                # implementation.
                embedding_retriever = EmbeddingRetriever(
                    index_store=index_store, top_k=3
                )
                chunks = embedding_retriever.retrieve("your query text")
                contents = [chunk.content for chunk in chunks]
                print(f"embedding retriever results:{contents}")
        """
        self._top_k = top_k
        self._query_rewrite = query_rewrite
        self._index_store = index_store
        self._rerank = rerank or DefaultRanker(self._top_k)
        self._retrieve_strategy = retrieve_strategy

    def load_document(self, chunks: List[Chunk], **kwargs: Dict[str, Any]) -> List[str]:
        """Load document in vector database.

        Args:
            chunks (List[Chunk]): document chunks.
            **kwargs: accepted for interface compatibility; currently not
                forwarded to the index store.

        Return:
            List[str]: chunk ids.
        """
        return self._index_store.load_document(chunks)

    def _retrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks.

        The synchronous path searches with the original query only; no query
        rewrite is applied here.

        Args:
            query (str): query text
            filters: metadata filters.

        Return:
            List[Chunk]: list of chunks
        """
        return cast(
            List[Chunk], self._index_store.similar_search(query, self._top_k, filters)
        )

    def _retrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        The synchronous path searches with the original query only and then
        reranks the candidates.

        Args:
            query (str): query text
            score_threshold (float): score threshold
            filters: metadata filters.

        Return:
            List[Chunk]: list of chunks with score
        """
        candidates_with_score = cast(
            List[Chunk],
            self._index_store.similar_search_with_scores(
                query, self._top_k, score_threshold, filters
            ),
        )
        return self._rerank.rank(candidates_with_score, query)

    async def _aretrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Args:
            query (str): query text.
            filters: metadata filters.

        Return:
            List[Chunk]: flat list of chunks found for the original query and,
                when query rewrite is enabled, for the rewritten queries.
        """
        queries = [query]
        if self._query_rewrite:
            # A first search with the original query supplies the context the
            # rewriter works from.
            context_tasks = [
                self._similarity_search(
                    query, filters, root_tracer.get_current_span_id()
                )
            ]
            chunks = await self._run_async_tasks(context_tasks)
            context = "\n".join([chunk.content for chunk in chunks])
            new_queries = await self._query_rewrite.rewrite(
                origin_query=query, context=context, nums=1
            )
            queries.extend(new_queries)
        search_tasks = [
            self._similarity_search(
                each_query, filters, root_tracer.get_current_span_id()
            )
            for each_query in queries
        ]
        # Every task yields its own list of chunks; merge them so the result
        # is a flat List[Chunk] as annotated (and as ``_retrieve`` returns).
        return await self._run_async_tasks(search_tasks)

    async def _aretrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold
            filters: metadata filters.

        Return:
            List[Chunk]: list of chunks with score
        """
        queries = [query]
        if self._query_rewrite:
            with root_tracer.start_span(
                "dbgpt.rag.retriever.embeddings.query_rewrite.similarity_search",
                metadata={"query": query, "score_threshold": score_threshold},
            ):
                # A first search with the original query supplies the context
                # the rewriter works from.
                context_tasks = [
                    self._similarity_search(
                        query, filters, root_tracer.get_current_span_id()
                    )
                ]
                chunks = await self._run_async_tasks(context_tasks)
                context = "\n".join([chunk.content for chunk in chunks])
            with root_tracer.start_span(
                "dbgpt.rag.retriever.embeddings.query_rewrite.rewrite",
                metadata={"query": query, "context": context, "nums": 1},
            ):
                new_queries = await self._query_rewrite.rewrite(
                    origin_query=query, context=context, nums=1
                )
                queries.extend(new_queries)

        with root_tracer.start_span(
            "dbgpt.rag.retriever.embeddings.similarity_search_with_score",
            metadata={"query": query, "score_threshold": score_threshold},
        ):
            scored_tasks = [
                self._similarity_search_with_score(
                    each_query,
                    score_threshold,
                    filters,
                    root_tracer.get_current_span_id(),
                )
                for each_query in queries
            ]
            # Merge the per-query result lists into one candidate list.
            new_candidates_with_score = await self._run_async_tasks(scored_tasks)

        with root_tracer.start_span(
            "dbgpt.rag.retriever.embeddings.rerank",
            metadata={
                "query": query,
                "score_threshold": score_threshold,
                "rerank_cls": self._rerank.__class__.__name__,
            },
        ):
            # Rerank against the original query, not the rewritten ones.
            new_candidates_with_score = await self._rerank.arank(
                new_candidates_with_score, query
            )
            return new_candidates_with_score

    async def _similarity_search(
        self,
        query: str,
        filters: Optional[MetadataFilters] = None,
        parent_span_id: Optional[str] = None,
    ) -> List[Chunk]:
        """Run the index store's blocking similar search inside a tracer span.

        Args:
            query (str): query text.
            filters: metadata filters.
            parent_span_id (Optional[str]): id of the span the new tracer span
                is attached to.

        Return:
            List[Chunk]: chunks returned by the index store.
        """
        with root_tracer.start_span(
            "dbgpt.rag.retriever.embeddings.similarity_search",
            parent_span_id,
            metadata={
                "query": query,
            },
        ):
            return await blocking_func_to_async_no_executor(
                self._index_store.similar_search, query, self._top_k, filters
            )

    async def _run_async_tasks(self, tasks) -> List[Chunk]:
        """Await the search tasks and merge their results into one list.

        Args:
            tasks: awaitables that each resolve to a list of chunks.

        Return:
            List[Chunk]: concatenation of all task results, in task order;
                an empty list when there is nothing to merge.
        """
        results = await run_async_tasks(tasks=tasks, concurrency_limit=1)
        if not results:
            # ``reduce`` without an initial value raises on an empty sequence.
            return []
        merged = reduce(lambda acc, group: acc + group, results)
        return cast(List[Chunk], merged)

    async def _similarity_search_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
        parent_span_id: Optional[str] = None,
    ) -> List[Chunk]:
        """Run the index store's async scored similar search in a tracer span.

        Args:
            query (str): query text.
            score_threshold (float): score threshold passed to the index store.
            filters: metadata filters.
            parent_span_id (Optional[str]): id of the span the new tracer span
                is attached to.

        Return:
            List[Chunk]: chunks with score returned by the index store.
        """
        with root_tracer.start_span(
            "dbgpt.rag.retriever.embeddings._do_similarity_search_with_score",
            parent_span_id,
            metadata={
                "query": query,
                "score_threshold": score_threshold,
            },
        ):
            return await self._index_store.asimilar_search_with_scores(
                query, self._top_k, score_threshold, filters
            )

    @classmethod
    def name(cls) -> str:
        """Return retriever name."""
        return "embedding_retriever"
|