feat(model): Support deploying rerank models (#1522)

This commit is contained in:
Fangyin Cheng
2024-05-16 14:50:16 +08:00
committed by GitHub
parent 559affe87d
commit 593e974405
29 changed files with 814 additions and 75 deletions

View File

@@ -4,7 +4,7 @@ import logging
from typing import List, Optional, Type, cast
from dbgpt.configs.model_config import get_device
from dbgpt.core import Embeddings
from dbgpt.core import Embeddings, RerankEmbeddings
from dbgpt.model.parameter import (
BaseEmbeddingModelParameters,
EmbeddingModelParameters,
@@ -66,6 +66,38 @@ class EmbeddingLoader:
kwargs = param.build_kwargs(model_name=param.model_path)
return HuggingFaceEmbeddings(**kwargs)
def load_rerank_model(
    self, model_name: str, param: BaseEmbeddingModelParameters
) -> RerankEmbeddings:
    """Instantiate a rerank embeddings model, tracing the load in a span.

    Args:
        model_name: Registered name of the rerank model to load.
        param: Parameters describing how to construct the model.

    Returns:
        A ready-to-use ``RerankEmbeddings`` instance: an OpenAPI-backed
        proxy client for ``rerank_proxy_http_openapi``, otherwise a local
        cross-encoder model built from ``param``.
    """
    span_metadata = {
        "model_name": model_name,
        "run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
        "params": _get_dict_from_obj(param),
        "sys_infos": _get_dict_from_obj(get_system_info()),
    }
    with root_tracer.start_span(
        "EmbeddingLoader.load_rerank_model",
        span_type=SpanType.RUN,
        metadata=span_metadata,
    ):
        if model_name == "rerank_proxy_http_openapi":
            # Proxy path: forward rerank calls to a remote OpenAPI service.
            from dbgpt.rag.embedding.rerank import OpenAPIRerankEmbeddings

            proxy_param = cast(ProxyEmbeddingParameters, param)
            # Map only the proxy fields that are actually set; unset ones
            # fall back to the client's own defaults.
            init_kwargs = {}
            for attr_name, kwarg_name in (
                ("proxy_server_url", "api_url"),
                ("proxy_api_key", "api_key"),
                ("proxy_backend", "model_name"),
            ):
                value = getattr(proxy_param, attr_name)
                if value:
                    init_kwargs[kwarg_name] = value
            return OpenAPIRerankEmbeddings(**init_kwargs)

        # Local path: load a cross-encoder rerank model from model_path.
        from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings

        build_kwargs = param.build_kwargs(model_name=param.model_path)
        return CrossEncoderRerankEmbeddings(**build_kwargs)
def _parse_embedding_params(
model_name: Optional[str] = None,

View File

@@ -147,15 +147,14 @@ def _dynamic_model_parser() -> Optional[List[Type[BaseModelParameters]]]:
model_path = pre_args.get("model_path")
worker_type = pre_args.get("worker_type")
model_type = pre_args.get("model_type")
if model_name is None and model_type != ModelType.VLLM:
return None
if worker_type == WorkerType.TEXT2VEC:
return [
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
model_name, EmbeddingModelParameters
)
]
if model_name is None and model_type != ModelType.VLLM:
return None
llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
param_class = llm_adapter.model_param_class()
return [param_class]