feat(model): Support deploy rerank model (#1522)

This commit is contained in:
Fangyin Cheng
2024-05-16 14:50:16 +08:00
committed by GitHub
parent 559affe87d
commit 593e974405
29 changed files with 814 additions and 75 deletions

View File

@@ -4,7 +4,7 @@ import logging
from typing import List, Optional, Type, cast
from dbgpt.configs.model_config import get_device
from dbgpt.core import Embeddings
from dbgpt.core import Embeddings, RerankEmbeddings
from dbgpt.model.parameter import (
BaseEmbeddingModelParameters,
EmbeddingModelParameters,
@@ -66,6 +66,38 @@ class EmbeddingLoader:
kwargs = param.build_kwargs(model_name=param.model_path)
return HuggingFaceEmbeddings(**kwargs)
def load_rerank_model(
    self, model_name: str, param: BaseEmbeddingModelParameters
) -> RerankEmbeddings:
    """Instantiate the rerank model selected by ``model_name``.

    The construction runs inside a ``root_tracer`` span named
    ``EmbeddingLoader.load_rerank_model`` whose metadata records the model
    name, the parameters and the system information.

    Args:
        model_name: Name of the rerank model. The value
            ``"rerank_proxy_http_openapi"`` selects the OpenAPI proxy
            implementation; any other value selects the cross-encoder
            implementation.
        param: Model parameters. For the proxy model they are treated as
            ``ProxyEmbeddingParameters``; otherwise ``param.build_kwargs``
            is called with ``model_name=param.model_path``.

    Returns:
        An ``OpenAPIRerankEmbeddings`` for the proxy model name, otherwise a
        ``CrossEncoderRerankEmbeddings``.
    """
    span_metadata = {
        "model_name": model_name,
        "run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
        "params": _get_dict_from_obj(param),
        "sys_infos": _get_dict_from_obj(get_system_info()),
    }
    with root_tracer.start_span(
        "EmbeddingLoader.load_rerank_model",
        span_type=SpanType.RUN,
        metadata=span_metadata,
    ):
        if model_name not in ["rerank_proxy_http_openapi"]:
            # Local model: build it straight from the generic parameters.
            from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings

            return CrossEncoderRerankEmbeddings(
                **param.build_kwargs(model_name=param.model_path)
            )

        # Proxy model: forward only the settings that are actually set, so
        # that unset ones are simply not passed to the constructor.
        from dbgpt.rag.embedding.rerank import OpenAPIRerankEmbeddings

        proxy_param = cast(ProxyEmbeddingParameters, param)
        candidate_settings = (
            ("api_url", proxy_param.proxy_server_url),
            ("api_key", proxy_param.proxy_api_key),
            ("model_name", proxy_param.proxy_backend),
        )
        openapi_kwargs = {
            key: value for key, value in candidate_settings if value
        }
        return OpenAPIRerankEmbeddings(**openapi_kwargs)
def _parse_embedding_params(
model_name: Optional[str] = None,