Merge branch 'refs/heads/main' into feat/agent_optimize

This commit is contained in:
yhjun1026
2025-04-21 11:41:13 +08:00
2 changed files with 25 additions and 4 deletions

View File

@@ -340,11 +340,14 @@ class KnowledgeService:
else 0.3
)
if CFG.RERANK_MODEL is not None:
if top_k < int(CFG.RERANK_TOP_K) or top_k < 20:
app_config = CFG.SYSTEM_APP.config.configs.get("app_config")
rerank_top_k = app_config.rag.rerank_top_k
if app_config.models.rerankers:
if top_k < int(rerank_top_k) or top_k < 20:
# We use reranker, so if the top_k is less than 20,
# we need to set it to 20
top_k = max(int(CFG.RERANK_TOP_K), 20)
top_k = max(int(rerank_top_k), 20)
knowledge_space_retriever = KnowledgeSpaceRetriever(
space_id=space.id, top_k=top_k, system_app=CFG.SYSTEM_APP
@@ -360,7 +363,8 @@ class KnowledgeService:
)
recall_top_k = int(doc_recall_test_request.recall_top_k)
if CFG.RERANK_MODEL is not None:
if app_config.models.rerankers:
rerank_embeddings = RerankEmbeddingFactory.get_instance(
CFG.SYSTEM_APP
).create()

View File

@@ -5,7 +5,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import cachetools
from dbgpt._private.config import Config
from dbgpt.core import Chunk
from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
from dbgpt.util.cache_utils import cached
from .base import Resource, ResourceParameters, ResourceType
@@ -14,6 +17,8 @@ if TYPE_CHECKING:
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.storage.vector_store.filters import MetadataFilters
CFG = Config()
@dataclasses.dataclass
class RetrieverResourceParameters(ResourceParameters):
@@ -32,6 +37,12 @@ class RetrieverResource(Resource[ResourceParameters]):
"""Create a new RetrieverResource."""
self._name = name
self._retriever = retriever
app_config = CFG.SYSTEM_APP.config.configs.get("app_config")
rerank_embeddings = RerankEmbeddingFactory.get_instance(CFG.SYSTEM_APP).create()
self.need_rerank = bool(app_config.models.rerankers)
self.reranker = RerankEmbeddingsRanker(
rerank_embeddings, topk=app_config.rag.rerank_top_k
)
@property
def name(self) -> str:
@@ -77,6 +88,9 @@ class RetrieverResource(Resource[ResourceParameters]):
if not question:
raise ValueError("Question is required for knowledge resource.")
chunks = await self.retrieve(question)
if self.need_rerank and len(chunks) > 1:
chunks = self.reranker.rank(candidates_with_scores=chunks, query=question)
content = "\n".join(
[f"--{i}--:" + chunk.content for i, chunk in enumerate(chunks)]
)
@@ -97,6 +111,9 @@ class RetrieverResource(Resource[ResourceParameters]):
if not question:
raise ValueError("Question is required for knowledge resource.")
chunks = await self.retrieve(question)
if self.need_rerank and len(chunks) > 1:
chunks = self.reranker.rank(candidates_with_scores=chunks, query=question)
prompt_template = """Resources-{name}:\n {content}"""
prompt_template_zh = """资源-{name}:\n {content}"""
if lang == "en":