DB-GPT/dbgpt/serve/rag/retriever/knowledge_space.py
2024-08-28 21:05:27 +08:00

161 lines
5.4 KiB
Python

from typing import List, Optional
from dbgpt._private.config import Config
from dbgpt.component import ComponentType
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.core import Chunk
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.retriever import EmbeddingRetriever, QueryRewrite, Ranker
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.serve.rag.models.models import KnowledgeSpaceDao
from dbgpt.serve.rag.retriever.qa_retriever import QARetriever
from dbgpt.serve.rag.retriever.retriever_chain import RetrieverChain
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
CFG = Config()
class KnowledgeSpaceRetriever(BaseRetriever):
"""Knowledge Space retriever."""
def __init__(
self,
space_id: str = None,
top_k: Optional[int] = 4,
query_rewrite: Optional[QueryRewrite] = None,
rerank: Optional[Ranker] = None,
llm_model: Optional[str] = None,
):
"""
Args:
space_id (str): knowledge space name
top_k (Optional[int]): top k
query_rewrite: (Optional[QueryRewrite]) query rewrite
rerank: (Optional[Ranker]) rerank
"""
if space_id is None:
raise ValueError("space_id is required")
self._space_id = space_id
self._top_k = top_k
self._query_rewrite = query_rewrite
self._rerank = rerank
self._llm_model = llm_model
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
embedding_fn = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
from dbgpt.storage.vector_store.base import VectorStoreConfig
space_dao = KnowledgeSpaceDao()
space = space_dao.get_one({"id": space_id})
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
llm_client = DefaultLLMClient(worker_manager=worker_manager)
config = VectorStoreConfig(
name=space.name,
embedding_fn=embedding_fn,
llm_client=llm_client,
llm_model=self._llm_model,
)
self._vector_store_connector = VectorStoreConnector(
vector_store_type=space.vector_type,
vector_store_config=config,
)
self._executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
self._retriever_chain = RetrieverChain(
retrievers=[
QARetriever(space_id=space_id, top_k=top_k, embedding_fn=embedding_fn),
EmbeddingRetriever(
index_store=self._vector_store_connector.index_client,
top_k=top_k,
query_rewrite=self._query_rewrite,
rerank=self._rerank,
),
],
executor=self._executor,
)
def _retrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text.
filters: (Optional[MetadataFilters]) metadata filters.
Return:
List[Chunk]: list of chunks
"""
candidates = self._retriever_chain.retrieve(query=query, filters=filters)
return candidates
def _retrieve_with_score(
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
filters: (Optional[MetadataFilters]) metadata filters.
Return:
List[Chunk]: list of chunks with score
"""
candidates_with_scores = self._retriever_chain.retrieve_with_scores(
query, score_threshold, filters
)
return candidates_with_scores
async def _aretrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Retrieve knowledge chunks.
Args:
query (str): query text.
filters: (Optional[MetadataFilters]) metadata filters.
Return:
List[Chunk]: list of chunks
"""
candidates = await blocking_func_to_async(
self._executor, self._retrieve, query, filters
)
return candidates
async def _aretrieve_with_score(
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text.
score_threshold (float): score threshold.
filters: (Optional[MetadataFilters]) metadata filters.
Return:
List[Chunk]: list of chunks with score.
"""
return await self._retriever_chain.aretrieve_with_scores(
query, score_threshold, filters
)