feat(model): Support siliconflow rerank models (#2188)

This commit is contained in:
Fangyin Cheng
2024-12-11 18:36:44 +08:00
committed by GitHub
parent 23aedea092
commit abab4e3e65
5 changed files with 96 additions and 7 deletions

View File

@@ -109,6 +109,18 @@ class EmbeddingLoader:
if proxy_param.proxy_backend:
openapi_param["model_name"] = proxy_param.proxy_backend
return OpenAPIRerankEmbeddings(**openapi_param)
elif model_name in ["rerank_proxy_silicon_flow"]:
from dbgpt.rag.embedding.rerank import SiliconFlowRerankEmbeddings
proxy_param = cast(ProxyEmbeddingParameters, param)
openapi_param = {}
if proxy_param.proxy_server_url:
openapi_param["api_url"] = proxy_param.proxy_server_url
if proxy_param.proxy_api_key:
openapi_param["api_key"] = proxy_param.proxy_api_key
if proxy_param.proxy_backend:
openapi_param["model_name"] = proxy_param.proxy_backend
return SiliconFlowRerankEmbeddings(**openapi_param)
else:
from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings

View File

@@ -613,7 +613,16 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,proxy_qianfan,rerank_proxy_http_openapi",
ProxyEmbeddingParameters: [
"proxy_openai",
"proxy_azure",
"proxy_http_openapi",
"proxy_ollama",
"proxy_tongyi",
"proxy_qianfan",
"rerank_proxy_http_openapi",
"rerank_proxy_silicon_flow",
]
}
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
@@ -622,7 +631,6 @@ EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
def _update_embedding_config():
global EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG
for param_cls, models in _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG.items():
models = [m.strip() for m in models.split(",")]
for model in models:
if model not in EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG:
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG[model] = param_cls