mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-20 00:54:43 +00:00
feat(model): Support deploy rerank model (#1522)
This commit is contained in:
@@ -18,6 +18,7 @@ from dbgpt.core import (
|
||||
MessagesPlaceholder,
|
||||
SystemPromptTemplate,
|
||||
)
|
||||
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
|
||||
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||
from dbgpt.util.tracer import root_tracer, trace
|
||||
|
||||
@@ -37,7 +38,10 @@ class ChatKnowledge(BaseChat):
|
||||
- model_name:(str) llm model name
|
||||
- select_param:(str) space name
|
||||
"""
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.embedding.embedding_factory import (
|
||||
EmbeddingFactory,
|
||||
RerankEmbeddingFactory,
|
||||
)
|
||||
|
||||
self.knowledge_space = chat_param["select_param"]
|
||||
chat_param["chat_mode"] = ChatScene.ChatKnowledge
|
||||
@@ -83,10 +87,22 @@ class ChatKnowledge(BaseChat):
|
||||
model_name=self.llm_model,
|
||||
language=CFG.LANGUAGE,
|
||||
)
|
||||
reranker = None
|
||||
retriever_top_k = self.top_k
|
||||
if CFG.RERANK_MODEL:
|
||||
rerank_embeddings = RerankEmbeddingFactory.get_instance(
|
||||
CFG.SYSTEM_APP
|
||||
).create()
|
||||
reranker = RerankEmbeddingsRanker(rerank_embeddings, topk=CFG.RERANK_TOP_K)
|
||||
if retriever_top_k < CFG.RERANK_TOP_K or retriever_top_k < 20:
|
||||
# A reranker is in use, so if top_k is below CFG.RERANK_TOP_K or below 20,
|
||||
# raise it to max(CFG.RERANK_TOP_K, 20)
|
||||
retriever_top_k = max(CFG.RERANK_TOP_K, 20)
|
||||
self.embedding_retriever = EmbeddingRetriever(
|
||||
top_k=self.top_k,
|
||||
top_k=retriever_top_k,
|
||||
vector_store_connector=vector_store_connector,
|
||||
query_rewrite=query_rewrite,
|
||||
rerank=reranker,
|
||||
)
|
||||
self.prompt_template.template_is_strict = False
|
||||
self.relations = None
|
||||
|
Reference in New Issue
Block a user