feat(model): Support deploy rerank model (#1522)

This commit is contained in:
Fangyin Cheng
2024-05-16 14:50:16 +08:00
committed by GitHub
parent 559affe87d
commit 593e974405
29 changed files with 814 additions and 75 deletions

View File

@@ -227,6 +227,14 @@ class WebServerParameters(BaseParameters):
"text2vec --model_name xxx --model_path xxx`"
},
)
remote_rerank: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to enable remote rerank models. If it is True, you need"
" to start a rerank model through `dbgpt start worker --worker_type "
"text2vec --rerank --model_name xxx --model_path xxx`"
},
)
log_level: Optional[str] = field(
default=None,
metadata={

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import logging
from typing import Optional
from dbgpt._private.config import Config
from dbgpt.app.base import WebServerParameters
@@ -18,9 +19,14 @@ def initialize_components(
system_app: SystemApp,
embedding_model_name: str,
embedding_model_path: str,
rerank_model_name: Optional[str] = None,
rerank_model_path: Optional[str] = None,
):
# Lazy import to avoid high time cost
from dbgpt.app.initialization.embedding_component import _initialize_embedding_model
from dbgpt.app.initialization.embedding_component import (
_initialize_embedding_model,
_initialize_rerank_model,
)
from dbgpt.app.initialization.scheduler import DefaultScheduler
from dbgpt.app.initialization.serve_initialization import register_serve_apps
from dbgpt.datasource.manages.connector_manager import ConnectorManager
@@ -45,6 +51,7 @@ def initialize_components(
_initialize_embedding_model(
param, system_app, embedding_model_name, embedding_model_path
)
_initialize_rerank_model(param, system_app, rerank_model_name, rerank_model_path)
_initialize_model_cache(system_app)
_initialize_awel(system_app, param)
# Initialize resource manager of agent
@@ -89,6 +96,11 @@ def _initialize_agent(system_app: SystemApp):
def _initialize_resource_manager(system_app: SystemApp):
from dbgpt.agent.expand.resources.dbgpt_tool import list_dbgpt_support_models
from dbgpt.agent.expand.resources.host_tool import (
get_current_host_cpu_status,
get_current_host_memory_status,
get_current_host_system_load,
)
from dbgpt.agent.expand.resources.search_tool import baidu_search
from dbgpt.agent.resource.base import ResourceType
from dbgpt.agent.resource.manage import get_resource_manager, initialize_resource
@@ -104,6 +116,10 @@ def _initialize_resource_manager(system_app: SystemApp):
# Register a search tool
rm.register_resource(resource_instance=baidu_search)
rm.register_resource(resource_instance=list_dbgpt_support_models)
# Register host tools
rm.register_resource(resource_instance=get_current_host_cpu_status)
rm.register_resource(resource_instance=get_current_host_memory_status)
rm.register_resource(resource_instance=get_current_host_system_load)
def _initialize_openapi(system_app: SystemApp):

View File

@@ -146,11 +146,24 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
embedding_model_name = CFG.EMBEDDING_MODEL
embedding_model_path = EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
rerank_model_name = CFG.RERANK_MODEL
rerank_model_path = None
if rerank_model_name:
rerank_model_path = CFG.RERANK_MODEL_PATH or EMBEDDING_MODEL_CONFIG.get(
rerank_model_name
)
server_init(param, system_app)
mount_routers(app)
model_start_listener = _create_model_start_listener(system_app)
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
initialize_components(
param,
system_app,
embedding_model_name,
embedding_model_path,
rerank_model_name,
rerank_model_path,
)
system_app.on_init()
# Migration db storage, so you db models must be imported before this
@@ -161,7 +174,13 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
if not param.light:
print("Model Unified Deployment Mode!")
if not param.remote_embedding:
# Embedding model is running in the same process, set embedding_model_name
# and embedding_model_path to None
embedding_model_name, embedding_model_path = None, None
if not param.remote_rerank:
# Rerank model is running in the same process, set rerank_model_name and
# rerank_model_path to None
rerank_model_name, rerank_model_path = None, None
initialize_worker_manager_in_client(
app=app,
model_name=model_name,
@@ -169,6 +188,8 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
local_port=param.port,
embedding_model_name=embedding_model_name,
embedding_model_path=embedding_model_path,
rerank_model_name=rerank_model_name,
rerank_model_path=rerank_model_path,
start_listener=model_start_listener,
system_app=system_app,
)

View File

@@ -1,10 +1,14 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Type
from typing import TYPE_CHECKING, Any, Optional, Type
from dbgpt.component import ComponentType, SystemApp
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.core import Embeddings, RerankEmbeddings
from dbgpt.rag.embedding.embedding_factory import (
EmbeddingFactory,
RerankEmbeddingFactory,
)
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
@@ -32,6 +36,26 @@ def _initialize_embedding_model(
)
def _initialize_rerank_model(
    param: "WebServerParameters",
    system_app: SystemApp,
    rerank_model_name: Optional[str] = None,
    rerank_model_path: Optional[str] = None,
):
    """Register a rerank embedding factory with the system app.

    No-op when ``rerank_model_name`` is falsy. Otherwise registers either a
    remote factory (rerank model served by a separate worker process) or a
    local factory (model loaded in this process), depending on
    ``param.remote_rerank``.

    Args:
        param: Web server startup parameters; only ``remote_rerank`` is read.
        system_app: Application container the factory is registered into.
        rerank_model_name: Name of the rerank model; ``None``/empty disables
            rerank entirely.
        rerank_model_path: Local filesystem path of the rerank model (only
            used in local mode).
    """
    if not rerank_model_name:
        return
    if param.remote_rerank:
        logger.info("Register remote RemoteRerankEmbeddingFactory")
        system_app.register(RemoteRerankEmbeddingFactory, model_name=rerank_model_name)
    else:
        # Fixed: message was an f-string with no placeholders (lint F541);
        # a plain string matches the remote branch's logging style.
        logger.info("Register local LocalRerankEmbeddingFactory")
        system_app.register(
            LocalRerankEmbeddingFactory,
            default_model_name=rerank_model_name,
            default_model_path=rerank_model_path,
        )
class RemoteEmbeddingFactory(EmbeddingFactory):
def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
super().__init__(system_app=system_app)
@@ -105,3 +129,81 @@ class LocalEmbeddingFactory(EmbeddingFactory):
loader = EmbeddingLoader()
# Ignore model_name args
return loader.load(self._default_model_name, model_params)
class RemoteRerankEmbeddingFactory(RerankEmbeddingFactory):
    """Factory producing rerank embeddings backed by a remote model worker.

    The actual model runs in a separate worker process; ``create`` returns a
    client-side wrapper that forwards rerank requests through the worker
    manager.
    """

    def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
        super().__init__(system_app=system_app)
        self._default_model_name = model_name
        self.kwargs = kwargs
        self.system_app = system_app

    def init_app(self, system_app):
        # Re-bind the app container when the component framework initializes us.
        self.system_app = system_app

    def create(
        self, model_name: str = None, embedding_cls: Type = None
    ) -> "RerankEmbeddings":
        """Build a remote rerank embeddings client.

        ``model_name`` overrides the factory default when given; passing a
        custom ``embedding_cls`` is not supported for remote rerank.
        """
        # Lazy imports keep server startup cheap when rerank is unused.
        from dbgpt.model.cluster import WorkerManagerFactory
        from dbgpt.model.cluster.embedding.remote_embedding import (
            RemoteRerankEmbeddings,
        )

        if embedding_cls:
            raise NotImplementedError

        factory = self.system_app.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        )
        worker_manager = factory.create()
        selected_name = model_name or self._default_model_name
        return RemoteRerankEmbeddings(selected_name, worker_manager)
class LocalRerankEmbeddingFactory(RerankEmbeddingFactory):
    """Factory that loads and serves a rerank model inside this process.

    The model is loaded once, eagerly, in ``__init__``; every subsequent
    ``create`` call hands back that same instance.
    """

    def __init__(
        self,
        system_app,
        default_model_name: str = None,
        default_model_path: str = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(system_app=system_app)
        self._default_model_name = default_model_name
        self._default_model_path = default_model_path
        self._kwargs = kwargs
        # Load eagerly so a misconfigured model fails at startup, not on
        # the first rerank request.
        self._model = self._load_model()

    def init_app(self, system_app):
        # Nothing to wire up: the model was already loaded in __init__.
        pass

    def create(
        self, model_name: str = None, embedding_cls: Type = None
    ) -> "RerankEmbeddings":
        """Return the pre-loaded rerank model.

        ``model_name`` is ignored (the factory holds a single model);
        a custom ``embedding_cls`` is not supported.
        """
        if embedding_cls:
            raise NotImplementedError
        return self._model

    def _load_model(self) -> "RerankEmbeddings":
        """Resolve model parameters and load the rerank model locally."""
        # Lazy imports avoid pulling in the model stack unless rerank is used.
        from dbgpt.model.adapter.embeddings_loader import (
            EmbeddingLoader,
            _parse_embedding_params,
        )
        from dbgpt.model.parameter import (
            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
            BaseEmbeddingModelParameters,
            EmbeddingModelParameters,
        )

        # Fall back to the generic parameter class when the model has no
        # dedicated one registered.
        param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
            self._default_model_name, EmbeddingModelParameters
        )
        model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
            model_name=self._default_model_name,
            model_path=self._default_model_path,
            param_cls=param_cls,
            **self._kwargs,
        )
        logger.info(model_params)
        # The loader keys off the params; model_name is informational here.
        return EmbeddingLoader().load_rerank_model(
            self._default_model_name, model_params
        )

View File

@@ -18,6 +18,7 @@ from dbgpt.core import (
MessagesPlaceholder,
SystemPromptTemplate,
)
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.util.tracer import root_tracer, trace
@@ -37,7 +38,10 @@ class ChatKnowledge(BaseChat):
- model_name:(str) llm model name
- select_param:(str) space name
"""
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.embedding.embedding_factory import (
EmbeddingFactory,
RerankEmbeddingFactory,
)
self.knowledge_space = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatKnowledge
@@ -83,10 +87,22 @@ class ChatKnowledge(BaseChat):
model_name=self.llm_model,
language=CFG.LANGUAGE,
)
reranker = None
retriever_top_k = self.top_k
if CFG.RERANK_MODEL:
rerank_embeddings = RerankEmbeddingFactory.get_instance(
CFG.SYSTEM_APP
).create()
reranker = RerankEmbeddingsRanker(rerank_embeddings, topk=CFG.RERANK_TOP_K)
if retriever_top_k < CFG.RERANK_TOP_K or retriever_top_k < 20:
# We use reranker, so if the top_k is less than 20,
# we need to set it to 20
retriever_top_k = max(CFG.RERANK_TOP_K, 20)
self.embedding_retriever = EmbeddingRetriever(
top_k=self.top_k,
top_k=retriever_top_k,
vector_store_connector=vector_store_connector,
query_rewrite=query_rewrite,
rerank=reranker,
)
self.prompt_template.template_is_strict = False
self.relations = None