fix:graph retrieve bug (#1884)

This commit is contained in:
Aries-ckt 2024-08-28 21:05:27 +08:00 committed by GitHub
parent 1cb7e35295
commit bb5d2d1f3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 34 additions and 14 deletions

View File

@@ -98,6 +98,7 @@ class ChatKnowledge(BaseChat):
 top_k=retriever_top_k,
 query_rewrite=query_rewrite,
 rerank=reranker,
+llm_model=self.llm_model,
 )
 self.prompt_template.template_is_strict = False

View File

@@ -4,6 +4,8 @@ from dbgpt._private.config import Config
 from dbgpt.component import ComponentType
 from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
 from dbgpt.core import Chunk
+from dbgpt.model import DefaultLLMClient
+from dbgpt.model.cluster import WorkerManagerFactory
 from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
 from dbgpt.rag.retriever import EmbeddingRetriever, QueryRewrite, Ranker
 from dbgpt.rag.retriever.base import BaseRetriever
@@ -26,6 +28,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
 top_k: Optional[int] = 4,
 query_rewrite: Optional[QueryRewrite] = None,
 rerank: Optional[Ranker] = None,
+llm_model: Optional[str] = None,
 ):
 """
 Args:
@@ -40,6 +43,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
 self._top_k = top_k
 self._query_rewrite = query_rewrite
 self._rerank = rerank
+self._llm_model = llm_model
 embedding_factory = CFG.SYSTEM_APP.get_component(
 "embedding_factory", EmbeddingFactory
 )
@@ -50,9 +54,19 @@ class KnowledgeSpaceRetriever(BaseRetriever):
 space_dao = KnowledgeSpaceDao()
 space = space_dao.get_one({"id": space_id})
-config = VectorStoreConfig(name=space.name, embedding_fn=embedding_fn)
+worker_manager = CFG.SYSTEM_APP.get_component(
+ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
+).create()
+llm_client = DefaultLLMClient(worker_manager=worker_manager)
+config = VectorStoreConfig(
+name=space.name,
+embedding_fn=embedding_fn,
+llm_client=llm_client,
+llm_model=self._llm_model,
+)
 self._vector_store_connector = VectorStoreConnector(
-vector_store_type=CFG.VECTOR_STORE_TYPE,
+vector_store_type=space.vector_type,
 vector_store_config=config,
 )
 self._executor = CFG.SYSTEM_APP.get_component(
@@ -141,7 +155,6 @@ class KnowledgeSpaceRetriever(BaseRetriever):
 Return:
 List[Chunk]: list of chunks with score.
 """
-candidates_with_score = await blocking_func_to_async(
-self._executor, self._retrieve_with_score, query, score_threshold, filters
+return await self._retriever_chain.aretrieve_with_scores(
+query, score_threshold, filters
 )
-return candidates_with_score

View File

@@ -38,17 +38,20 @@ class RetrieverChain(BaseRetriever):
 async def _aretrieve(
 self, query: str, filters: Optional[MetadataFilters] = None
 ) -> List[Chunk]:
-"""Retrieve knowledge chunks.
+"""Async retrieve knowledge chunks.
 Args:
 query (str): query text
 filters: (Optional[MetadataFilters]) metadata filters.
 Return:
 List[Chunk]: list of chunks
 """
-candidates = await blocking_func_to_async(
-self._executor, self._retrieve, query, filters
-)
-return candidates
+for retriever in self._retrievers:
+candidates = await retriever.aretrieve(
+query=query, filters=filters
+)
+if candidates:
+return candidates
+return []
 def _retrieve_with_score(
 self,
@@ -85,7 +88,10 @@ class RetrieverChain(BaseRetriever):
 Return:
 List[Chunk]: list of chunks with score
 """
-candidates_with_score = await blocking_func_to_async(
-self._executor, self._retrieve_with_score, query, score_threshold, filters
-)
-return candidates_with_score
+for retriever in self._retrievers:
+candidates_with_scores = await retriever.aretrieve_with_scores(
+query=query, score_threshold=score_threshold, filters=filters
+)
+if candidates_with_scores:
+return candidates_with_scores
+return []