fix:graph retrieve bug (#1884)

This commit is contained in:
Aries-ckt 2024-08-28 21:05:27 +08:00 committed by GitHub
parent 1cb7e35295
commit bb5d2d1f3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 34 additions and 14 deletions

View File

@ -98,6 +98,7 @@ class ChatKnowledge(BaseChat):
top_k=retriever_top_k, top_k=retriever_top_k,
query_rewrite=query_rewrite, query_rewrite=query_rewrite,
rerank=reranker, rerank=reranker,
llm_model=self.llm_model,
) )
self.prompt_template.template_is_strict = False self.prompt_template.template_is_strict = False

View File

@ -4,6 +4,8 @@ from dbgpt._private.config import Config
from dbgpt.component import ComponentType from dbgpt.component import ComponentType
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.core import Chunk from dbgpt.core import Chunk
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.retriever import EmbeddingRetriever, QueryRewrite, Ranker from dbgpt.rag.retriever import EmbeddingRetriever, QueryRewrite, Ranker
from dbgpt.rag.retriever.base import BaseRetriever from dbgpt.rag.retriever.base import BaseRetriever
@ -26,6 +28,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
top_k: Optional[int] = 4, top_k: Optional[int] = 4,
query_rewrite: Optional[QueryRewrite] = None, query_rewrite: Optional[QueryRewrite] = None,
rerank: Optional[Ranker] = None, rerank: Optional[Ranker] = None,
llm_model: Optional[str] = None,
): ):
""" """
Args: Args:
@ -40,6 +43,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
self._top_k = top_k self._top_k = top_k
self._query_rewrite = query_rewrite self._query_rewrite = query_rewrite
self._rerank = rerank self._rerank = rerank
self._llm_model = llm_model
embedding_factory = CFG.SYSTEM_APP.get_component( embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory "embedding_factory", EmbeddingFactory
) )
@ -50,9 +54,19 @@ class KnowledgeSpaceRetriever(BaseRetriever):
space_dao = KnowledgeSpaceDao() space_dao = KnowledgeSpaceDao()
space = space_dao.get_one({"id": space_id}) space = space_dao.get_one({"id": space_id})
config = VectorStoreConfig(name=space.name, embedding_fn=embedding_fn) worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
llm_client = DefaultLLMClient(worker_manager=worker_manager)
config = VectorStoreConfig(
name=space.name,
embedding_fn=embedding_fn,
llm_client=llm_client,
llm_model=self._llm_model,
)
self._vector_store_connector = VectorStoreConnector( self._vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE, vector_store_type=space.vector_type,
vector_store_config=config, vector_store_config=config,
) )
self._executor = CFG.SYSTEM_APP.get_component( self._executor = CFG.SYSTEM_APP.get_component(
@ -141,7 +155,6 @@ class KnowledgeSpaceRetriever(BaseRetriever):
Return: Return:
List[Chunk]: list of chunks with score. List[Chunk]: list of chunks with score.
""" """
candidates_with_score = await blocking_func_to_async( return await self._retriever_chain.aretrieve_with_scores(
self._executor, self._retrieve_with_score, query, score_threshold, filters query, score_threshold, filters
) )
return candidates_with_score

View File

@ -38,17 +38,20 @@ class RetrieverChain(BaseRetriever):
async def _aretrieve( async def _aretrieve(
self, query: str, filters: Optional[MetadataFilters] = None self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]: ) -> List[Chunk]:
"""Retrieve knowledge chunks. """Async retrieve knowledge chunks.
Args: Args:
query (str): query text query (str): query text
filters: (Optional[MetadataFilters]) metadata filters. filters: (Optional[MetadataFilters]) metadata filters.
Return: Return:
List[Chunk]: list of chunks List[Chunk]: list of chunks
""" """
candidates = await blocking_func_to_async( for retriever in self._retrievers:
self._executor, self._retrieve, query, filters candidates = await retriever.aretrieve(
query=query, filters=filters
) )
if candidates:
return candidates return candidates
return []
def _retrieve_with_score( def _retrieve_with_score(
self, self,
@ -85,7 +88,10 @@ class RetrieverChain(BaseRetriever):
Return: Return:
List[Chunk]: list of chunks with score List[Chunk]: list of chunks with score
""" """
candidates_with_score = await blocking_func_to_async( for retriever in self._retrievers:
self._executor, self._retrieve_with_score, query, score_threshold, filters candidates_with_scores = await retriever.aretrieve_with_scores(
query=query, score_threshold=score_threshold, filters=filters
) )
return candidates_with_score if candidates_with_scores:
return candidates_with_scores
return []