DB-GPT/dbgpt/rag/retriever/embedding.py
"""Embedding retriever."""
from functools import reduce
from typing import Any, Dict, List, Optional, cast

from dbgpt.core import Chunk
from dbgpt.rag.index.base import IndexStoreBase
from dbgpt.rag.retriever.base import BaseRetriever, RetrieverStrategy
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
from dbgpt.util.tracer import root_tracer


class EmbeddingRetriever(BaseRetriever):
    """Embedding retriever."""

    def __init__(
self,
index_store: IndexStoreBase,
top_k: int = 4,
query_rewrite: Optional[QueryRewrite] = None,
rerank: Optional[Ranker] = None,
retrieve_strategy: Optional[RetrieverStrategy] = RetrieverStrategy.EMBEDDING,
):
"""Create EmbeddingRetriever.
Args:
index_store(IndexStore): vector store connector
top_k (int): top k
query_rewrite (Optional[QueryRewrite]): query rewrite
rerank (Ranker): rerank
Examples:
.. code-block:: python
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.rag.embedding.embedding_factory import (
DefaultEmbeddingFactory,
)
embedding_factory = DefaultEmbeddingFactory()
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.storage.vector_store.connector import VectorStoreConnector
embedding_fn = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
vector_name = "test"
config = ChromaVectorConfig(name=vector_name, embedding_fn=embedding_fn)
vector_store_connector = VectorStoreConnector(
vector_store_type="Chroma",
vector_store_config=config,
)
embedding_retriever = EmbeddingRetriever(
top_k=3, vector_store_connector=vector_store_connector
)
chunks = embedding_retriever.retrieve("your query text")
print(
f"embedding retriever results:{[chunk.content for chunk in chunks]}"
)
"""
self._top_k = top_k
self._query_rewrite = query_rewrite
self._index_store = index_store
self._rerank = rerank or DefaultRanker(self._top_k)
self._retrieve_strategy = retrieve_strategy

    def load_document(self, chunks: List[Chunk], **kwargs: Dict[str, Any]) -> List[str]:
        """Load document in vector database.

        Args:
            chunks (List[Chunk]): document chunks.

        Return:
            List[str]: chunk ids.
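
        Examples:
            A minimal usage sketch, assuming ``retriever`` is an
            ``EmbeddingRetriever`` built as in the class-level example:

            .. code-block:: python

                from dbgpt.core import Chunk

                chunks = [
                    Chunk(content="DB-GPT is an AI native data app framework."),
                    Chunk(content="EmbeddingRetriever searches a vector index."),
                ]
                chunk_ids = retriever.load_document(chunks)
                print(f"loaded {len(chunk_ids)} chunks")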
"""
return self._index_store.load_document(chunks)

    def _retrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Args:
            query (str): query text
            filters: metadata filters.

        Return:
            List[Chunk]: list of chunks
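
        Examples:
            A minimal sketch of the synchronous path, assuming ``retriever`` is an
            ``EmbeddingRetriever`` with documents already loaded; the public
            ``retrieve()`` entry point dispatches here, and the ``"source"`` metadata
            key is only an illustration of whatever keys your chunk metadata uses:

            .. code-block:: python

                from dbgpt.storage.vector_store.filters import (
                    MetadataFilter,
                    MetadataFilters,
                )

                filters = MetadataFilters(
                    filters=[MetadataFilter(key="source", value="faq.md")]
                )
                chunks = retriever.retrieve("what is DB-GPT?", filters)
                print([chunk.content for chunk in chunks])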
"""
queries = [query]
candidates = [
self._index_store.similar_search(query, self._top_k, filters)
for query in queries
]
res_candidates = cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
return res_candidates

    def _retrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold
            filters: metadata filters.

        Return:
            List[Chunk]: list of chunks with score
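
        Examples:
            A minimal sketch, assuming ``retriever`` is an ``EmbeddingRetriever``
            with documents already loaded; the public ``retrieve_with_scores()``
            entry point dispatches here, and the threshold value is illustrative
            (it depends on the embedding model and vector store):

            .. code-block:: python

                chunks = retriever.retrieve_with_scores(
                    "what is DB-GPT?", score_threshold=0.3
                )
                for chunk in chunks:
                    print(chunk.score, chunk.content)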
"""
queries = [query]
candidates_with_score = [
self._index_store.similar_search_with_scores(
query, self._top_k, score_threshold, filters
)
for query in queries
]
new_candidates_with_score = cast(
List[Chunk], reduce(lambda x, y: x + y, candidates_with_score)
)
new_candidates_with_score = self._rerank.rank(new_candidates_with_score, query)
return new_candidates_with_score

    async def _aretrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Args:
            query (str): query text.
            filters: metadata filters.

        Return:
            List[Chunk]: list of chunks
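
        Examples:
            A minimal sketch of the async path, assuming ``retriever`` is an
            ``EmbeddingRetriever`` with documents already loaded; the public
            ``aretrieve()`` entry point dispatches here, and query rewriting only
            runs when a ``QueryRewrite`` instance was passed to the constructor:

            .. code-block:: python

                import asyncio

                chunks = asyncio.run(retriever.aretrieve("what is DB-GPT?"))
                print([chunk.content for chunk in chunks])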
"""
queries = [query]
if self._query_rewrite:
candidates_tasks = [
self._similarity_search(
query, filters, root_tracer.get_current_span_id()
)
for query in queries
]
chunks = await self._run_async_tasks(candidates_tasks)
context = "\n".join([chunk.content for chunk in chunks])
new_queries = await self._query_rewrite.rewrite(
origin_query=query, context=context, nums=1
)
queries.extend(new_queries)
candidates = [
self._similarity_search(query, filters, root_tracer.get_current_span_id())
for query in queries
]
new_candidates = await run_async_tasks(tasks=candidates, concurrency_limit=1)
return new_candidates

    async def _aretrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold
            filters: metadata filters.

        Return:
            List[Chunk]: list of chunks with score
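
        Examples:
            A minimal sketch, assuming ``retriever`` is an ``EmbeddingRetriever``
            with documents already loaded; the public ``aretrieve_with_scores()``
            entry point dispatches here. Candidates from the original and any
            rewritten queries are merged and then reranked by the configured
            ``Ranker``. The threshold value is illustrative:

            .. code-block:: python

                import asyncio

                chunks = asyncio.run(
                    retriever.aretrieve_with_scores("what is DB-GPT?", 0.3)
                )
                print([(chunk.score, chunk.content) for chunk in chunks])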
"""
queries = [query]
if self._query_rewrite:
with root_tracer.start_span(
"dbgpt.rag.retriever.embeddings.query_rewrite.similarity_search",
metadata={"query": query, "score_threshold": score_threshold},
):
candidates_tasks = [
self._similarity_search(
query, filters, root_tracer.get_current_span_id()
)
for query in queries
]
chunks = await self._run_async_tasks(candidates_tasks)
context = "\n".join([chunk.content for chunk in chunks])
with root_tracer.start_span(
"dbgpt.rag.retriever.embeddings.query_rewrite.rewrite",
metadata={"query": query, "context": context, "nums": 1},
):
new_queries = await self._query_rewrite.rewrite(
origin_query=query, context=context, nums=1
)
queries.extend(new_queries)
with root_tracer.start_span(
"dbgpt.rag.retriever.embeddings.similarity_search_with_score",
metadata={"query": query, "score_threshold": score_threshold},
):
candidates_with_score = [
self._similarity_search_with_score(
query, score_threshold, filters, root_tracer.get_current_span_id()
)
for query in queries
]
res_candidates_with_score = await run_async_tasks(
tasks=candidates_with_score, concurrency_limit=1
)
new_candidates_with_score = cast(
List[Chunk], reduce(lambda x, y: x + y, res_candidates_with_score)
)
with root_tracer.start_span(
"dbgpt.rag.retriever.embeddings.rerank",
metadata={
"query": query,
"score_threshold": score_threshold,
"rerank_cls": self._rerank.__class__.__name__,
},
):
new_candidates_with_score = await self._rerank.arank(
new_candidates_with_score, query
)
return new_candidates_with_score

    async def _similarity_search(
self,
query,
filters: Optional[MetadataFilters] = None,
parent_span_id: Optional[str] = None,
) -> List[Chunk]:
"""Similar search."""
with root_tracer.start_span(
"dbgpt.rag.retriever.embeddings.similarity_search",
parent_span_id,
metadata={
"query": query,
},
):
return await blocking_func_to_async_no_executor(
self._index_store.similar_search, query, self._top_k, filters
)

    async def _run_async_tasks(self, tasks) -> List[Chunk]:
"""Run async tasks."""
candidates = await run_async_tasks(tasks=tasks, concurrency_limit=1)
candidates = reduce(lambda x, y: x + y, candidates)
return cast(List[Chunk], candidates)

    async def _similarity_search_with_score(
self,
query,
score_threshold,
filters: Optional[MetadataFilters] = None,
parent_span_id: Optional[str] = None,
) -> List[Chunk]:
"""Similar search with score."""
with root_tracer.start_span(
"dbgpt.rag.retriever.embeddings._do_similarity_search_with_score",
parent_span_id,
metadata={
"query": query,
"score_threshold": score_threshold,
},
):
return await self._index_store.asimilar_search_with_scores(
query, self._top_k, score_threshold, filters
)

    @classmethod
def name(cls):
"""Return retriever name."""
return "embedding_retriever"