Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-01 16:18:27 +00:00)
fix(reranker) The rerank model is used during the knowledge base recall for chat scenarios and recall test scenarios (#2638)
fix https://github.com/eosphoros-ai/DB-GPT/issues/2636

# Description

Currently, the reranker model is not used for knowledge base recall. In the recall test function, `CFG.RERANK_MODEL` is always None, and in chat scenarios the recalled chunks are not reranked either.

# How Has This Been Tested?

# Checklist:

- [x] My code follows the style guidelines of this project
- [x] I have already rebased the commits and made the commit message conform to the project standard.
- [x] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] Any dependent changes have been merged and published in downstream modules
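For context, here is a minimal sketch of the gating this change introduces on the knowledge-service side: reranking is applied only when `app_config.models.rerankers` is configured, instead of checking `CFG.RERANK_MODEL`. The helper name `rerank_recalled_chunks` is hypothetical; the factory and ranker calls follow the diff below.

```python
# Hypothetical helper sketching the new gating; the calls mirror the diff below.
from dbgpt._private.config import Config
from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker

CFG = Config()


def rerank_recalled_chunks(chunks, query: str, top_k: int):
    """Rerank recalled chunks when a reranker is configured; otherwise keep the recall order."""
    app_config = CFG.SYSTEM_APP.config.configs.get("app_config")
    if not app_config.models.rerankers:
        # No reranker configured: return the recalled chunks unchanged.
        return chunks[:top_k]
    rerank_embeddings = RerankEmbeddingFactory.get_instance(CFG.SYSTEM_APP).create()
    reranker = RerankEmbeddingsRanker(rerank_embeddings, topk=top_k)
    return reranker.rank(candidates_with_scores=chunks, query=query)
```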
This commit is contained in: commit a6680610b9
@@ -340,11 +340,14 @@ class KnowledgeService:
             else 0.3
         )

-        if CFG.RERANK_MODEL is not None:
-            if top_k < int(CFG.RERANK_TOP_K) or top_k < 20:
+        app_config = CFG.SYSTEM_APP.config.configs.get("app_config")
+        rerank_top_k = app_config.rag.rerank_top_k
+
+        if app_config.models.rerankers:
+            if top_k < int(rerank_top_k) or top_k < 20:
                 # We use reranker, so if the top_k is less than 20,
                 # we need to set it to 20
-                top_k = max(int(CFG.RERANK_TOP_K), 20)
+                top_k = max(int(rerank_top_k), 20)

         knowledge_space_retriever = KnowledgeSpaceRetriever(
             space_id=space.id, top_k=top_k, system_app=CFG.SYSTEM_APP
@@ -360,7 +363,8 @@ class KnowledgeService:
         )

         recall_top_k = int(doc_recall_test_request.recall_top_k)
-        if CFG.RERANK_MODEL is not None:
+
+        if app_config.models.rerankers:
             rerank_embeddings = RerankEmbeddingFactory.get_instance(
                 CFG.SYSTEM_APP
             ).create()
@@ -5,7 +5,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

 import cachetools

+from dbgpt._private.config import Config
 from dbgpt.core import Chunk
+from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory
+from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
 from dbgpt.util.cache_utils import cached

 from .base import Resource, ResourceParameters, ResourceType
@@ -14,6 +17,8 @@ if TYPE_CHECKING:
     from dbgpt.rag.retriever.base import BaseRetriever
     from dbgpt.storage.vector_store.filters import MetadataFilters

+CFG = Config()
+

 @dataclasses.dataclass
 class RetrieverResourceParameters(ResourceParameters):
@@ -32,6 +37,12 @@ class RetrieverResource(Resource[ResourceParameters]):
         """Create a new RetrieverResource."""
         self._name = name
         self._retriever = retriever
+        app_config = CFG.SYSTEM_APP.config.configs.get("app_config")
+        rerank_embeddings = RerankEmbeddingFactory.get_instance(CFG.SYSTEM_APP).create()
+        self.need_rerank = bool(app_config.models.rerankers)
+        self.reranker = RerankEmbeddingsRanker(
+            rerank_embeddings, topk=app_config.rag.rerank_top_k
+        )

     @property
     def name(self) -> str:
@@ -77,6 +88,9 @@ class RetrieverResource(Resource[ResourceParameters]):
         if not question:
             raise ValueError("Question is required for knowledge resource.")
         chunks = await self.retrieve(question)
+        if self.need_rerank and len(chunks) > 1:
+            chunks = self.reranker.rank(candidates_with_scores=chunks, query=question)
+
         content = "\n".join(
             [f"--{i}--:" + chunk.content for i, chunk in enumerate(chunks)]
         )
@@ -97,6 +111,9 @@ class RetrieverResource(Resource[ResourceParameters]):
         if not question:
             raise ValueError("Question is required for knowledge resource.")
         chunks = await self.retrieve(question)
+        if self.need_rerank and len(chunks) > 1:
+            chunks = self.reranker.rank(candidates_with_scores=chunks, query=question)
+
         prompt_template = """Resources-{name}:\n {content}"""
         prompt_template_zh = """资源-{name}:\n {content}"""
         if lang == "en":
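As a closing illustration, here is a hedged sketch of how the retrieval path of `RetrieverResource` behaves after this change; `resource`, `question`, and the helper name are placeholders, while the `need_rerank` gate and the `rank(...)` call mirror the two patched methods above. Reranking is skipped when at most one chunk is recalled, since ordering a single candidate is a no-op.

```python
# Sketch only: caller-side view of the reranked retrieval path added above.
# `resource` is assumed to be a RetrieverResource constructed as in the diff.
async def retrieve_with_rerank(resource, question: str):
    chunks = await resource.retrieve(question)
    if resource.need_rerank and len(chunks) > 1:
        # Rerank the recalled chunks against the question, as in the patched methods.
        chunks = resource.reranker.rank(candidates_with_scores=chunks, query=question)
    return chunks
```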
|
Loading…
Reference in New Issue
Block a user