mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-07 20:10:08 +00:00
feat(model): Support deploy rerank model (#1522)
This commit is contained in:
@@ -227,6 +227,14 @@ class WebServerParameters(BaseParameters):
|
||||
"text2vec --model_name xxx --model_path xxx`"
|
||||
},
|
||||
)
|
||||
remote_rerank: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to enable remote rerank models. If it is True, you need"
|
||||
" to start a rerank model through `dbgpt start worker --worker_type "
|
||||
"text2vec --rerank --model_name xxx --model_path xxx`"
|
||||
},
|
||||
)
|
||||
log_level: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
|
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.app.base import WebServerParameters
|
||||
@@ -18,9 +19,14 @@ def initialize_components(
|
||||
system_app: SystemApp,
|
||||
embedding_model_name: str,
|
||||
embedding_model_path: str,
|
||||
rerank_model_name: Optional[str] = None,
|
||||
rerank_model_path: Optional[str] = None,
|
||||
):
|
||||
# Lazy import to avoid high time cost
|
||||
from dbgpt.app.initialization.embedding_component import _initialize_embedding_model
|
||||
from dbgpt.app.initialization.embedding_component import (
|
||||
_initialize_embedding_model,
|
||||
_initialize_rerank_model,
|
||||
)
|
||||
from dbgpt.app.initialization.scheduler import DefaultScheduler
|
||||
from dbgpt.app.initialization.serve_initialization import register_serve_apps
|
||||
from dbgpt.datasource.manages.connector_manager import ConnectorManager
|
||||
@@ -45,6 +51,7 @@ def initialize_components(
|
||||
_initialize_embedding_model(
|
||||
param, system_app, embedding_model_name, embedding_model_path
|
||||
)
|
||||
_initialize_rerank_model(param, system_app, rerank_model_name, rerank_model_path)
|
||||
_initialize_model_cache(system_app)
|
||||
_initialize_awel(system_app, param)
|
||||
# Initialize resource manager of agent
|
||||
@@ -89,6 +96,11 @@ def _initialize_agent(system_app: SystemApp):
|
||||
|
||||
def _initialize_resource_manager(system_app: SystemApp):
|
||||
from dbgpt.agent.expand.resources.dbgpt_tool import list_dbgpt_support_models
|
||||
from dbgpt.agent.expand.resources.host_tool import (
|
||||
get_current_host_cpu_status,
|
||||
get_current_host_memory_status,
|
||||
get_current_host_system_load,
|
||||
)
|
||||
from dbgpt.agent.expand.resources.search_tool import baidu_search
|
||||
from dbgpt.agent.resource.base import ResourceType
|
||||
from dbgpt.agent.resource.manage import get_resource_manager, initialize_resource
|
||||
@@ -104,6 +116,10 @@ def _initialize_resource_manager(system_app: SystemApp):
|
||||
# Register a search tool
|
||||
rm.register_resource(resource_instance=baidu_search)
|
||||
rm.register_resource(resource_instance=list_dbgpt_support_models)
|
||||
# Register host tools
|
||||
rm.register_resource(resource_instance=get_current_host_cpu_status)
|
||||
rm.register_resource(resource_instance=get_current_host_memory_status)
|
||||
rm.register_resource(resource_instance=get_current_host_system_load)
|
||||
|
||||
|
||||
def _initialize_openapi(system_app: SystemApp):
|
||||
|
@@ -146,11 +146,24 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
|
||||
|
||||
embedding_model_name = CFG.EMBEDDING_MODEL
|
||||
embedding_model_path = EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
|
||||
rerank_model_name = CFG.RERANK_MODEL
|
||||
rerank_model_path = None
|
||||
if rerank_model_name:
|
||||
rerank_model_path = CFG.RERANK_MODEL_PATH or EMBEDDING_MODEL_CONFIG.get(
|
||||
rerank_model_name
|
||||
)
|
||||
|
||||
server_init(param, system_app)
|
||||
mount_routers(app)
|
||||
model_start_listener = _create_model_start_listener(system_app)
|
||||
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
|
||||
initialize_components(
|
||||
param,
|
||||
system_app,
|
||||
embedding_model_name,
|
||||
embedding_model_path,
|
||||
rerank_model_name,
|
||||
rerank_model_path,
|
||||
)
|
||||
system_app.on_init()
|
||||
|
||||
# Migration db storage, so you db models must be imported before this
|
||||
@@ -161,7 +174,13 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
|
||||
if not param.light:
|
||||
print("Model Unified Deployment Mode!")
|
||||
if not param.remote_embedding:
|
||||
# Embedding model is running in the same process, set embedding_model_name
|
||||
# and embedding_model_path to None
|
||||
embedding_model_name, embedding_model_path = None, None
|
||||
if not param.remote_rerank:
|
||||
# Rerank model is running in the same process, set rerank_model_name and
|
||||
# rerank_model_path to None
|
||||
rerank_model_name, rerank_model_path = None, None
|
||||
initialize_worker_manager_in_client(
|
||||
app=app,
|
||||
model_name=model_name,
|
||||
@@ -169,6 +188,8 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
|
||||
local_port=param.port,
|
||||
embedding_model_name=embedding_model_name,
|
||||
embedding_model_path=embedding_model_path,
|
||||
rerank_model_name=rerank_model_name,
|
||||
rerank_model_path=rerank_model_path,
|
||||
start_listener=model_start_listener,
|
||||
system_app=system_app,
|
||||
)
|
||||
|
@@ -1,10 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Type
|
||||
from typing import TYPE_CHECKING, Any, Optional, Type
|
||||
|
||||
from dbgpt.component import ComponentType, SystemApp
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.core import Embeddings, RerankEmbeddings
|
||||
from dbgpt.rag.embedding.embedding_factory import (
|
||||
EmbeddingFactory,
|
||||
RerankEmbeddingFactory,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.embeddings.base import Embeddings
|
||||
@@ -32,6 +36,26 @@ def _initialize_embedding_model(
|
||||
)
|
||||
|
||||
|
||||
def _initialize_rerank_model(
|
||||
param: "WebServerParameters",
|
||||
system_app: SystemApp,
|
||||
rerank_model_name: Optional[str] = None,
|
||||
rerank_model_path: Optional[str] = None,
|
||||
):
|
||||
if not rerank_model_name:
|
||||
return
|
||||
if param.remote_rerank:
|
||||
logger.info("Register remote RemoteRerankEmbeddingFactory")
|
||||
system_app.register(RemoteRerankEmbeddingFactory, model_name=rerank_model_name)
|
||||
else:
|
||||
logger.info(f"Register local LocalRerankEmbeddingFactory")
|
||||
system_app.register(
|
||||
LocalRerankEmbeddingFactory,
|
||||
default_model_name=rerank_model_name,
|
||||
default_model_path=rerank_model_path,
|
||||
)
|
||||
|
||||
|
||||
class RemoteEmbeddingFactory(EmbeddingFactory):
|
||||
def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
|
||||
super().__init__(system_app=system_app)
|
||||
@@ -105,3 +129,81 @@ class LocalEmbeddingFactory(EmbeddingFactory):
|
||||
loader = EmbeddingLoader()
|
||||
# Ignore model_name args
|
||||
return loader.load(self._default_model_name, model_params)
|
||||
|
||||
|
||||
class RemoteRerankEmbeddingFactory(RerankEmbeddingFactory):
|
||||
def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
|
||||
super().__init__(system_app=system_app)
|
||||
self._default_model_name = model_name
|
||||
self.kwargs = kwargs
|
||||
self.system_app = system_app
|
||||
|
||||
def init_app(self, system_app):
|
||||
self.system_app = system_app
|
||||
|
||||
def create(
|
||||
self, model_name: str = None, embedding_cls: Type = None
|
||||
) -> "RerankEmbeddings":
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
from dbgpt.model.cluster.embedding.remote_embedding import (
|
||||
RemoteRerankEmbeddings,
|
||||
)
|
||||
|
||||
if embedding_cls:
|
||||
raise NotImplementedError
|
||||
worker_manager = self.system_app.get_component(
|
||||
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
|
||||
).create()
|
||||
return RemoteRerankEmbeddings(
|
||||
model_name or self._default_model_name, worker_manager
|
||||
)
|
||||
|
||||
|
||||
class LocalRerankEmbeddingFactory(RerankEmbeddingFactory):
|
||||
def __init__(
|
||||
self,
|
||||
system_app,
|
||||
default_model_name: str = None,
|
||||
default_model_path: str = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(system_app=system_app)
|
||||
self._default_model_name = default_model_name
|
||||
self._default_model_path = default_model_path
|
||||
self._kwargs = kwargs
|
||||
self._model = self._load_model()
|
||||
|
||||
def init_app(self, system_app):
|
||||
pass
|
||||
|
||||
def create(
|
||||
self, model_name: str = None, embedding_cls: Type = None
|
||||
) -> "RerankEmbeddings":
|
||||
if embedding_cls:
|
||||
raise NotImplementedError
|
||||
return self._model
|
||||
|
||||
def _load_model(self) -> "RerankEmbeddings":
|
||||
from dbgpt.model.adapter.embeddings_loader import (
|
||||
EmbeddingLoader,
|
||||
_parse_embedding_params,
|
||||
)
|
||||
from dbgpt.model.parameter import (
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
|
||||
BaseEmbeddingModelParameters,
|
||||
EmbeddingModelParameters,
|
||||
)
|
||||
|
||||
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
|
||||
self._default_model_name, EmbeddingModelParameters
|
||||
)
|
||||
model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
|
||||
model_name=self._default_model_name,
|
||||
model_path=self._default_model_path,
|
||||
param_cls=param_cls,
|
||||
**self._kwargs,
|
||||
)
|
||||
logger.info(model_params)
|
||||
loader = EmbeddingLoader()
|
||||
# Ignore model_name args
|
||||
return loader.load_rerank_model(self._default_model_name, model_params)
|
||||
|
@@ -18,6 +18,7 @@ from dbgpt.core import (
|
||||
MessagesPlaceholder,
|
||||
SystemPromptTemplate,
|
||||
)
|
||||
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
|
||||
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||
from dbgpt.util.tracer import root_tracer, trace
|
||||
|
||||
@@ -37,7 +38,10 @@ class ChatKnowledge(BaseChat):
|
||||
- model_name:(str) llm model name
|
||||
- select_param:(str) space name
|
||||
"""
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.embedding.embedding_factory import (
|
||||
EmbeddingFactory,
|
||||
RerankEmbeddingFactory,
|
||||
)
|
||||
|
||||
self.knowledge_space = chat_param["select_param"]
|
||||
chat_param["chat_mode"] = ChatScene.ChatKnowledge
|
||||
@@ -83,10 +87,22 @@ class ChatKnowledge(BaseChat):
|
||||
model_name=self.llm_model,
|
||||
language=CFG.LANGUAGE,
|
||||
)
|
||||
reranker = None
|
||||
retriever_top_k = self.top_k
|
||||
if CFG.RERANK_MODEL:
|
||||
rerank_embeddings = RerankEmbeddingFactory.get_instance(
|
||||
CFG.SYSTEM_APP
|
||||
).create()
|
||||
reranker = RerankEmbeddingsRanker(rerank_embeddings, topk=CFG.RERANK_TOP_K)
|
||||
if retriever_top_k < CFG.RERANK_TOP_K or retriever_top_k < 20:
|
||||
# We use reranker, so if the top_k is less than 20,
|
||||
# we need to set it to 20
|
||||
retriever_top_k = max(CFG.RERANK_TOP_K, 20)
|
||||
self.embedding_retriever = EmbeddingRetriever(
|
||||
top_k=self.top_k,
|
||||
top_k=retriever_top_k,
|
||||
vector_store_connector=vector_store_connector,
|
||||
query_rewrite=query_rewrite,
|
||||
rerank=reranker,
|
||||
)
|
||||
self.prompt_template.template_is_strict = False
|
||||
self.relations = None
|
||||
|
Reference in New Issue
Block a user