feat(model): Support siliconflow rerank models (#2188)

This commit is contained in:
Fangyin Cheng 2024-12-11 18:36:44 +08:00 committed by GitHub
parent 23aedea092
commit abab4e3e65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 96 additions and 7 deletions

View File

@ -322,6 +322,7 @@ EMBEDDING_MODEL_CONFIG = {
"bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
# Proxy rerank model
"rerank_proxy_http_openapi": "rerank_proxy_http_openapi",
"rerank_proxy_silicon_flow": "rerank_proxy_silicon_flow",
}

View File

@ -109,6 +109,18 @@ class EmbeddingLoader:
if proxy_param.proxy_backend:
openapi_param["model_name"] = proxy_param.proxy_backend
return OpenAPIRerankEmbeddings(**openapi_param)
elif model_name in ["rerank_proxy_silicon_flow"]:
from dbgpt.rag.embedding.rerank import SiliconFlowRerankEmbeddings
proxy_param = cast(ProxyEmbeddingParameters, param)
openapi_param = {}
if proxy_param.proxy_server_url:
openapi_param["api_url"] = proxy_param.proxy_server_url
if proxy_param.proxy_api_key:
openapi_param["api_key"] = proxy_param.proxy_api_key
if proxy_param.proxy_backend:
openapi_param["model_name"] = proxy_param.proxy_backend
return SiliconFlowRerankEmbeddings(**openapi_param)
else:
from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings

View File

@ -613,7 +613,16 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,proxy_qianfan,rerank_proxy_http_openapi",
ProxyEmbeddingParameters: [
"proxy_openai",
"proxy_azure",
"proxy_http_openapi",
"proxy_ollama",
"proxy_tongyi",
"proxy_qianfan",
"rerank_proxy_http_openapi",
"rerank_proxy_silicon_flow",
]
}
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
@ -622,7 +631,6 @@ EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
def _update_embedding_config():
    """Fill the name -> parameter-class lookup from the class -> names table.

    Each value in ``_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG`` may be either
    a list of model names or a single comma-separated string of names. Names
    that are already present in ``EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG``
    are left untouched, so the first registration of a name wins.
    """
    global EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG
    for param_cls, models in _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG.items():
        # Only the comma-separated string form needs splitting; calling
        # ``.split`` on a list value would raise AttributeError.
        if isinstance(models, str):
            models = [m.strip() for m in models.split(",")]
        for model in models:
            if model not in EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG:
                EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG[model] = param_cls

View File

@ -17,7 +17,11 @@ from .embeddings import ( # noqa: F401
QianFanEmbeddings,
TongYiEmbeddings,
)
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
from .rerank import ( # noqa: F401
CrossEncoderRerankEmbeddings,
OpenAPIRerankEmbeddings,
SiliconFlowRerankEmbeddings,
)
__ALL__ = [
"CrossEncoderRerankEmbeddings",
@ -32,6 +36,7 @@ __ALL__ = [
"OllamaEmbeddings",
"OpenAPIEmbeddings",
"OpenAPIRerankEmbeddings",
"SiliconFlowRerankEmbeddings",
"QianFanEmbeddings",
"TongYiEmbeddings",
"WrappedEmbeddingFactory",

View File

@ -1,5 +1,6 @@
"""Re-rank embeddings."""
import os
from typing import Any, Dict, List, Optional, cast
import aiohttp
@ -104,6 +105,24 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
kwargs["session"] = session
super().__init__(**kwargs)
def _parse_results(self, response: Dict[str, Any]) -> List[float]:
    """Parse the response from the API.

    Args:
        response: The decoded JSON response from the API.

    Returns:
        List[float]: The value of the ``data`` field. An empty list is
            returned as-is rather than being treated as an error.

    Raises:
        RuntimeError: If ``data`` is missing or empty and the response
            carries a ``detail`` field (raised with that detail), if
            ``data`` is missing altogether, or if ``data`` is not a list.
    """
    data = response.get("data")
    # A server-side error message takes precedence whenever there is no
    # usable payload.
    if not data and "detail" in response:
        raise RuntimeError(response["detail"])
    # Only a missing field is an error; ``[]`` is a valid (empty) result.
    if data is None:
        raise RuntimeError("Cannot find results in the response")
    if not isinstance(data, list):
        raise RuntimeError("Results should be a list")
    return data
def predict(self, query: str, candidates: List[str]) -> List[float]:
"""Predict the rank scores of the candidates.
@ -126,7 +145,7 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
self.api_url, json=data, timeout=self.timeout, headers=headers
)
response.raise_for_status()
return response.json()["data"]
return self._parse_results(response.json())
async def apredict(self, query: str, candidates: List[str]) -> List[float]:
"""Predict the rank scores of the candidates asynchronously."""
@ -142,6 +161,50 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
async with session.post(self.api_url, json=data) as resp:
resp.raise_for_status()
response_data = await resp.json()
if "data" not in response_data:
raise RuntimeError(response_data["detail"])
return response_data["data"]
return self._parse_results(response_data)
class SiliconFlowRerankEmbeddings(OpenAPIRerankEmbeddings):
    """SiliconFlow Rerank Model.

    See `SiliconFlow API
    <https://docs.siliconflow.cn/api-reference/rerank/create-rerank>`_ for more details.
    """

    def __init__(self, **kwargs: Any):
        """Initialize the SiliconFlowRerankEmbeddings.

        Keyword arguments that are not supplied fall back to defaults:

        - ``api_key``: the ``SILICON_FLOW_API_KEY`` environment variable.
        - ``api_url``: ``$SILICON_FLOW_API_BASE`` with ``/rerank`` appended,
          or ``https://api.siliconflow.cn/v1/rerank`` when that variable is
          unset or empty.
        - ``model_name``: ``BAAI/bge-reranker-v2-m3``.

        All keyword arguments are then forwarded to the parent initializer.
        """
        # If the API key is not provided, try to get it from the environment
        if "api_key" not in kwargs:
            kwargs["api_key"] = os.getenv("SILICON_FLOW_API_KEY")

        if "api_url" not in kwargs:
            env_api_url = os.getenv("SILICON_FLOW_API_BASE")
            if env_api_url:
                # Drop trailing slashes so the joined URL has exactly one.
                env_api_url = env_api_url.rstrip("/")
                kwargs["api_url"] = env_api_url + "/rerank"
            else:
                kwargs["api_url"] = "https://api.siliconflow.cn/v1/rerank"

        if "model_name" not in kwargs:
            kwargs["model_name"] = "BAAI/bge-reranker-v2-m3"

        super().__init__(**kwargs)

    def _parse_results(self, response: Dict[str, Any]) -> List[float]:
        """Parse the response from the API.

        Args:
            response: The decoded JSON response from the API.

        Returns:
            List[float]: The ``relevance_score`` of every entry in
                ``results``, ordered by each entry's ``index`` (entries
                without an ``index`` sort as 0). An empty ``results`` list
                yields an empty list.

        Raises:
            RuntimeError: If ``results`` is missing, is not a list, or an
                entry has no ``relevance_score``.
        """
        results = response.get("results")
        # Only a missing field is an error; ``[]`` is a valid (empty) result.
        if results is None:
            raise RuntimeError("Cannot find results in the response")
        if not isinstance(results, list):
            raise RuntimeError("Results should be a list")

        # Sort by index, 0 in the first element
        ordered = sorted(results, key=lambda x: x.get("index", 0))

        scores = []
        for item in ordered:
            score = item.get("relevance_score")
            if score is None:
                # Fail with a clear message instead of ``float(None)``'s
                # opaque TypeError.
                raise RuntimeError("Cannot find relevance_score in the results")
            scores.append(float(score))
        return scores