diff --git a/packages/dbgpt-app/src/dbgpt_app/knowledge/service.py b/packages/dbgpt-app/src/dbgpt_app/knowledge/service.py index 798facff9..2befe7f65 100644 --- a/packages/dbgpt-app/src/dbgpt_app/knowledge/service.py +++ b/packages/dbgpt-app/src/dbgpt_app/knowledge/service.py @@ -340,11 +340,14 @@ class KnowledgeService: else 0.3 ) - if CFG.RERANK_MODEL is not None: - if top_k < int(CFG.RERANK_TOP_K) or top_k < 20: + app_config = CFG.SYSTEM_APP.config.configs.get("app_config") + rerank_top_k = app_config.rag.rerank_top_k + + if app_config.models.rerankers: + if top_k < int(rerank_top_k) or top_k < 20: # We use reranker, so if the top_k is less than 20, # we need to set it to 20 - top_k = max(int(CFG.RERANK_TOP_K), 20) + top_k = max(int(rerank_top_k), 20) knowledge_space_retriever = KnowledgeSpaceRetriever( space_id=space.id, top_k=top_k, system_app=CFG.SYSTEM_APP @@ -360,7 +363,8 @@ class KnowledgeService: ) recall_top_k = int(doc_recall_test_request.recall_top_k) - if CFG.RERANK_MODEL is not None: + + if app_config.models.rerankers: rerank_embeddings = RerankEmbeddingFactory.get_instance( CFG.SYSTEM_APP ).create() diff --git a/packages/dbgpt-core/src/dbgpt/agent/resource/knowledge.py b/packages/dbgpt-core/src/dbgpt/agent/resource/knowledge.py index aa18d0b53..03f6a42ad 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/resource/knowledge.py +++ b/packages/dbgpt-core/src/dbgpt/agent/resource/knowledge.py @@ -5,7 +5,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import cachetools +from dbgpt._private.config import Config from dbgpt.core import Chunk +from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory +from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker from dbgpt.util.cache_utils import cached from .base import Resource, ResourceParameters, ResourceType @@ -14,6 +17,8 @@ if TYPE_CHECKING: from dbgpt.rag.retriever.base import BaseRetriever from dbgpt.storage.vector_store.filters import MetadataFilters +CFG = Config() + @dataclasses.dataclass class RetrieverResourceParameters(ResourceParameters): @@ -32,6 +37,12 @@ class RetrieverResource(Resource[ResourceParameters]): """Create a new RetrieverResource.""" self._name = name self._retriever = retriever + app_config = CFG.SYSTEM_APP.config.configs.get("app_config") + rerank_embeddings = RerankEmbeddingFactory.get_instance(CFG.SYSTEM_APP).create() + self.need_rerank = bool(app_config.models.rerankers) + self.reranker = RerankEmbeddingsRanker( + rerank_embeddings, topk=app_config.rag.rerank_top_k + ) @property def name(self) -> str: @@ -77,6 +88,9 @@ class RetrieverResource(Resource[ResourceParameters]): if not question: raise ValueError("Question is required for knowledge resource.") chunks = await self.retrieve(question) + if self.need_rerank and len(chunks) > 1: + chunks = self.reranker.rank(candidates_with_scores=chunks, query=question) + content = "\n".join( [f"--{i}--:" + chunk.content for i, chunk in enumerate(chunks)] ) @@ -97,6 +111,9 @@ class RetrieverResource(Resource[ResourceParameters]): if not question: raise ValueError("Question is required for knowledge resource.") chunks = await self.retrieve(question) + if self.need_rerank and len(chunks) > 1: + chunks = self.reranker.rank(candidates_with_scores=chunks, query=question) + prompt_template = """Resources-{name}:\n {content}""" prompt_template_zh = """资源-{name}:\n {content}""" if lang == "en":