mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
fix:graph retrieve bug (#1884)
This commit is contained in:
parent
1cb7e35295
commit
bb5d2d1f3d
@ -98,6 +98,7 @@ class ChatKnowledge(BaseChat):
|
||||
top_k=retriever_top_k,
|
||||
query_rewrite=query_rewrite,
|
||||
rerank=reranker,
|
||||
llm_model=self.llm_model,
|
||||
)
|
||||
|
||||
self.prompt_template.template_is_strict = False
|
||||
|
@ -4,6 +4,8 @@ from dbgpt._private.config import Config
|
||||
from dbgpt.component import ComponentType
|
||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.model import DefaultLLMClient
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.retriever import EmbeddingRetriever, QueryRewrite, Ranker
|
||||
from dbgpt.rag.retriever.base import BaseRetriever
|
||||
@ -26,6 +28,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
top_k: Optional[int] = 4,
|
||||
query_rewrite: Optional[QueryRewrite] = None,
|
||||
rerank: Optional[Ranker] = None,
|
||||
llm_model: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -40,6 +43,7 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
self._top_k = top_k
|
||||
self._query_rewrite = query_rewrite
|
||||
self._rerank = rerank
|
||||
self._llm_model = llm_model
|
||||
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||
"embedding_factory", EmbeddingFactory
|
||||
)
|
||||
@ -50,9 +54,19 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
|
||||
space_dao = KnowledgeSpaceDao()
|
||||
space = space_dao.get_one({"id": space_id})
|
||||
config = VectorStoreConfig(name=space.name, embedding_fn=embedding_fn)
|
||||
worker_manager = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
llm_client = DefaultLLMClient(worker_manager=worker_manager)
|
||||
config = VectorStoreConfig(
|
||||
name=space.name,
|
||||
embedding_fn=embedding_fn,
|
||||
llm_client=llm_client,
|
||||
llm_model=self._llm_model,
|
||||
)
|
||||
|
||||
self._vector_store_connector = VectorStoreConnector(
|
||||
vector_store_type=CFG.VECTOR_STORE_TYPE,
|
||||
vector_store_type=space.vector_type,
|
||||
vector_store_config=config,
|
||||
)
|
||||
self._executor = CFG.SYSTEM_APP.get_component(
|
||||
@ -141,7 +155,6 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
||||
Return:
|
||||
List[Chunk]: list of chunks with score.
|
||||
"""
|
||||
candidates_with_score = await blocking_func_to_async(
|
||||
self._executor, self._retrieve_with_score, query, score_threshold, filters
|
||||
return await self._retriever_chain.aretrieve_with_scores(
|
||||
query, score_threshold, filters
|
||||
)
|
||||
return candidates_with_score
|
||||
|
@ -38,17 +38,20 @@ class RetrieverChain(BaseRetriever):
|
||||
async def _aretrieve(
|
||||
self, query: str, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve knowledge chunks.
|
||||
"""Async retrieve knowledge chunks.
|
||||
Args:
|
||||
query (str): query text
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
Return:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
candidates = await blocking_func_to_async(
|
||||
self._executor, self._retrieve, query, filters
|
||||
)
|
||||
return candidates
|
||||
for retriever in self._retrievers:
|
||||
candidates = await retriever.aretrieve(
|
||||
query=query, filters=filters
|
||||
)
|
||||
if candidates:
|
||||
return candidates
|
||||
return []
|
||||
|
||||
def _retrieve_with_score(
|
||||
self,
|
||||
@ -85,7 +88,10 @@ class RetrieverChain(BaseRetriever):
|
||||
Return:
|
||||
List[Chunk]: list of chunks with score
|
||||
"""
|
||||
candidates_with_score = await blocking_func_to_async(
|
||||
self._executor, self._retrieve_with_score, query, score_threshold, filters
|
||||
)
|
||||
return candidates_with_score
|
||||
for retriever in self._retrievers:
|
||||
candidates_with_scores = await retriever.aretrieve_with_scores(
|
||||
query=query, score_threshold=score_threshold, filters=filters
|
||||
)
|
||||
if candidates_with_scores:
|
||||
return candidates_with_scores
|
||||
return []
|
||||
|
Loading…
Reference in New Issue
Block a user