feat(model): Support deploy rerank model (#1522)

This commit is contained in:
Fangyin Cheng
2024-05-16 14:50:16 +08:00
committed by GitHub
parent 559affe87d
commit 593e974405
29 changed files with 814 additions and 75 deletions

View File

@@ -18,6 +18,7 @@ from dbgpt.core import (
MessagesPlaceholder,
SystemPromptTemplate,
)
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.util.tracer import root_tracer, trace
@@ -37,7 +38,10 @@ class ChatKnowledge(BaseChat):
- model_name:(str) llm model name
- select_param:(str) space name
"""
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.embedding.embedding_factory import (
EmbeddingFactory,
RerankEmbeddingFactory,
)
self.knowledge_space = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatKnowledge
@@ -83,10 +87,22 @@ class ChatKnowledge(BaseChat):
model_name=self.llm_model,
language=CFG.LANGUAGE,
)
reranker = None
retriever_top_k = self.top_k
if CFG.RERANK_MODEL:
rerank_embeddings = RerankEmbeddingFactory.get_instance(
CFG.SYSTEM_APP
).create()
reranker = RerankEmbeddingsRanker(rerank_embeddings, topk=CFG.RERANK_TOP_K)
if retriever_top_k < CFG.RERANK_TOP_K or retriever_top_k < 20:
# A reranker is in use, so widen the retrieval: raise top_k to at least
# 20, or to RERANK_TOP_K when that is larger than 20.
retriever_top_k = max(CFG.RERANK_TOP_K, 20)
self.embedding_retriever = EmbeddingRetriever(
top_k=self.top_k,
top_k=retriever_top_k,
vector_store_connector=vector_store_connector,
query_rewrite=query_rewrite,
rerank=reranker,
)
self.prompt_template.template_is_strict = False
self.relations = None