mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-06 02:46:40 +00:00
161 lines
5.4 KiB
Python
161 lines
5.4 KiB
Python
from typing import List, Optional
|
|
|
|
from dbgpt._private.config import Config
|
|
from dbgpt.component import ComponentType
|
|
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
|
|
from dbgpt.core import Chunk
|
|
from dbgpt.model import DefaultLLMClient
|
|
from dbgpt.model.cluster import WorkerManagerFactory
|
|
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
|
from dbgpt.rag.retriever import EmbeddingRetriever, QueryRewrite, Ranker
|
|
from dbgpt.rag.retriever.base import BaseRetriever
|
|
from dbgpt.serve.rag.connector import VectorStoreConnector
|
|
from dbgpt.serve.rag.models.models import KnowledgeSpaceDao
|
|
from dbgpt.serve.rag.retriever.qa_retriever import QARetriever
|
|
from dbgpt.serve.rag.retriever.retriever_chain import RetrieverChain
|
|
from dbgpt.storage.vector_store.filters import MetadataFilters
|
|
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
|
|
|
# Module-level configuration object. In this module it is read for
# ``SYSTEM_APP`` (component lookup) and ``EMBEDDING_MODEL`` (embedding model key).
CFG = Config()
|
|
|
|
|
|
class KnowledgeSpaceRetriever(BaseRetriever):
    """Knowledge Space retriever.

    Builds a ``RetrieverChain`` made of a ``QARetriever`` followed by an
    ``EmbeddingRetriever`` for one knowledge space and delegates every
    retrieval call to that chain.
    """

    def __init__(
        self,
        space_id: Optional[str] = None,
        top_k: Optional[int] = 4,
        query_rewrite: Optional[QueryRewrite] = None,
        rerank: Optional[Ranker] = None,
        llm_model: Optional[str] = None,
    ):
        """Create a retriever bound to one knowledge space.

        Args:
            space_id (str): knowledge space id; the space is looked up through
                ``KnowledgeSpaceDao`` with ``{"id": space_id}``. Required.
            top_k (Optional[int]): top k, passed to both chained retrievers.
            query_rewrite (Optional[QueryRewrite]): query rewrite, passed to
                the embedding retriever.
            rerank (Optional[Ranker]): rerank, passed to the embedding
                retriever.
            llm_model (Optional[str]): LLM model name, forwarded to the
                vector store configuration.

        Raises:
            ValueError: if ``space_id`` is ``None`` or no knowledge space is
                found for it.
        """
        if space_id is None:
            raise ValueError("space_id is required")
        self._space_id = space_id
        self._top_k = top_k
        self._query_rewrite = query_rewrite
        self._rerank = rerank
        self._llm_model = llm_model

        embedding_factory = CFG.SYSTEM_APP.get_component(
            "embedding_factory", EmbeddingFactory
        )
        embedding_fn = embedding_factory.create(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
        )
        # NOTE(review): kept as a function-scope import exactly like the
        # original; presumably this avoids an import cycle - confirm before
        # hoisting it to module level.
        from dbgpt.storage.vector_store.base import VectorStoreConfig

        space_dao = KnowledgeSpaceDao()
        space = space_dao.get_one({"id": space_id})
        if space is None:
            # Fail with a clear message instead of an AttributeError on
            # ``space.name`` further down.
            raise ValueError(f"Knowledge space {space_id} not found")

        worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        llm_client = DefaultLLMClient(worker_manager=worker_manager)
        config = VectorStoreConfig(
            name=space.name,
            embedding_fn=embedding_fn,
            llm_client=llm_client,
            llm_model=self._llm_model,
        )

        self._vector_store_connector = VectorStoreConnector(
            vector_store_type=space.vector_type,
            vector_store_config=config,
        )
        self._executor = CFG.SYSTEM_APP.get_component(
            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
        ).create()

        self._retriever_chain = RetrieverChain(
            retrievers=[
                QARetriever(space_id=space_id, top_k=top_k, embedding_fn=embedding_fn),
                EmbeddingRetriever(
                    index_store=self._vector_store_connector.index_client,
                    top_k=top_k,
                    query_rewrite=self._query_rewrite,
                    rerank=self._rerank,
                ),
            ],
            executor=self._executor,
        )

    def _retrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Args:
            query (str): query text.
            filters: (Optional[MetadataFilters]) metadata filters.

        Returns:
            List[Chunk]: list of chunks returned by the retriever chain.
        """
        candidates = self._retriever_chain.retrieve(query=query, filters=filters)
        return candidates

    def _retrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Args:
            query (str): query text.
            score_threshold (float): score threshold.
            filters: (Optional[MetadataFilters]) metadata filters.

        Returns:
            List[Chunk]: list of chunks with score returned by the retriever
                chain.
        """
        candidates_with_scores = self._retriever_chain.retrieve_with_scores(
            query, score_threshold, filters
        )
        return candidates_with_scores

    async def _aretrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks asynchronously.

        Runs the synchronous ``_retrieve`` on ``self._executor`` through
        ``blocking_func_to_async``.

        Args:
            query (str): query text.
            filters: (Optional[MetadataFilters]) metadata filters.

        Returns:
            List[Chunk]: list of chunks.
        """
        candidates = await blocking_func_to_async(
            self._executor, self._retrieve, query, filters
        )
        return candidates

    async def _aretrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score asynchronously.

        Awaits the retriever chain's own ``aretrieve_with_scores``.

        Args:
            query (str): query text.
            score_threshold (float): score threshold.
            filters: (Optional[MetadataFilters]) metadata filters.

        Returns:
            List[Chunk]: list of chunks with score.
        """
        return await self._retriever_chain.aretrieve_with_scores(
            query, score_threshold, filters
        )
|