mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 19:11:52 +00:00
feat(model): Support deploy rerank model (#1522)
This commit is contained in:
@@ -4,7 +4,7 @@ import logging
|
||||
from typing import List, Optional, Type, cast
|
||||
|
||||
from dbgpt.configs.model_config import get_device
|
||||
from dbgpt.core import Embeddings
|
||||
from dbgpt.core import Embeddings, RerankEmbeddings
|
||||
from dbgpt.model.parameter import (
|
||||
BaseEmbeddingModelParameters,
|
||||
EmbeddingModelParameters,
|
||||
@@ -66,6 +66,38 @@ class EmbeddingLoader:
|
||||
kwargs = param.build_kwargs(model_name=param.model_path)
|
||||
return HuggingFaceEmbeddings(**kwargs)
|
||||
|
||||
def load_rerank_model(
    self, model_name: str, param: BaseEmbeddingModelParameters
) -> RerankEmbeddings:
    """Build a rerank model for ``model_name`` inside a tracing span.

    When ``model_name`` is ``"rerank_proxy_http_openapi"`` the parameters
    are treated as ``ProxyEmbeddingParameters`` and an
    ``OpenAPIRerankEmbeddings`` is created from whichever of the proxy
    server URL, API key and backend name are set (truthy). For any other
    name a ``CrossEncoderRerankEmbeddings`` is created from
    ``param.build_kwargs(model_name=param.model_path)``.

    Args:
        model_name: Name used to select the rerank implementation; it is
            also recorded in the span metadata.
        param: Model parameters used to configure the returned instance.

    Returns:
        The constructed rerank embeddings instance.
    """
    span_metadata = {
        "model_name": model_name,
        "run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
        "params": _get_dict_from_obj(param),
        "sys_infos": _get_dict_from_obj(get_system_info()),
    }
    with root_tracer.start_span(
        "EmbeddingLoader.load_rerank_model",
        span_type=SpanType.RUN,
        metadata=span_metadata,
    ):
        if model_name not in ("rerank_proxy_http_openapi",):
            # Imported here, inside the branch that uses it, rather than
            # at module import time.
            from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings

            encoder_kwargs = param.build_kwargs(model_name=param.model_path)
            return CrossEncoderRerankEmbeddings(**encoder_kwargs)

        from dbgpt.rag.embedding.rerank import OpenAPIRerankEmbeddings

        proxy_cfg = cast(ProxyEmbeddingParameters, param)
        # (attribute on the proxy parameters, constructor keyword) pairs;
        # unset/empty attributes are left out so that the
        # OpenAPIRerankEmbeddings defaults apply.
        attr_to_keyword = (
            ("proxy_server_url", "api_url"),
            ("proxy_api_key", "api_key"),
            ("proxy_backend", "model_name"),
        )
        openapi_kwargs = {}
        for attr_name, keyword in attr_to_keyword:
            setting = getattr(proxy_cfg, attr_name)
            if setting:
                openapi_kwargs[keyword] = setting
        return OpenAPIRerankEmbeddings(**openapi_kwargs)
|
||||
|
||||
|
||||
def _parse_embedding_params(
|
||||
model_name: Optional[str] = None,
|
||||
|
Reference in New Issue
Block a user