mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-02 00:28:00 +00:00
feat(model): Support siliconflow rerank models (#2188)
This commit is contained in:
parent
23aedea092
commit
abab4e3e65
@ -322,6 +322,7 @@ EMBEDDING_MODEL_CONFIG = {
|
||||
"bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
|
||||
# Proxy rerank model
|
||||
"rerank_proxy_http_openapi": "rerank_proxy_http_openapi",
|
||||
"rerank_proxy_silicon_flow": "rerank_proxy_silicon_flow",
|
||||
}
|
||||
|
||||
|
||||
|
@ -109,6 +109,18 @@ class EmbeddingLoader:
|
||||
if proxy_param.proxy_backend:
|
||||
openapi_param["model_name"] = proxy_param.proxy_backend
|
||||
return OpenAPIRerankEmbeddings(**openapi_param)
|
||||
elif model_name in ["rerank_proxy_silicon_flow"]:
|
||||
from dbgpt.rag.embedding.rerank import SiliconFlowRerankEmbeddings
|
||||
|
||||
proxy_param = cast(ProxyEmbeddingParameters, param)
|
||||
openapi_param = {}
|
||||
if proxy_param.proxy_server_url:
|
||||
openapi_param["api_url"] = proxy_param.proxy_server_url
|
||||
if proxy_param.proxy_api_key:
|
||||
openapi_param["api_key"] = proxy_param.proxy_api_key
|
||||
if proxy_param.proxy_backend:
|
||||
openapi_param["model_name"] = proxy_param.proxy_backend
|
||||
return SiliconFlowRerankEmbeddings(**openapi_param)
|
||||
else:
|
||||
from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings
|
||||
|
||||
|
@ -613,7 +613,16 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
|
||||
|
||||
|
||||
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
|
||||
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,proxy_qianfan,rerank_proxy_http_openapi",
|
||||
ProxyEmbeddingParameters: [
|
||||
"proxy_openai",
|
||||
"proxy_azure",
|
||||
"proxy_http_openapi",
|
||||
"proxy_ollama",
|
||||
"proxy_tongyi",
|
||||
"proxy_qianfan",
|
||||
"rerank_proxy_http_openapi",
|
||||
"rerank_proxy_silicon_flow",
|
||||
]
|
||||
}
|
||||
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
|
||||
@ -622,7 +631,6 @@ EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
|
||||
def _update_embedding_config():
|
||||
global EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG
|
||||
for param_cls, models in _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG.items():
|
||||
models = [m.strip() for m in models.split(",")]
|
||||
for model in models:
|
||||
if model not in EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG:
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG[model] = param_cls
|
||||
|
@ -17,7 +17,11 @@ from .embeddings import ( # noqa: F401
|
||||
QianFanEmbeddings,
|
||||
TongYiEmbeddings,
|
||||
)
|
||||
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
|
||||
from .rerank import ( # noqa: F401
|
||||
CrossEncoderRerankEmbeddings,
|
||||
OpenAPIRerankEmbeddings,
|
||||
SiliconFlowRerankEmbeddings,
|
||||
)
|
||||
|
||||
__ALL__ = [
|
||||
"CrossEncoderRerankEmbeddings",
|
||||
@ -32,6 +36,7 @@ __ALL__ = [
|
||||
"OllamaEmbeddings",
|
||||
"OpenAPIEmbeddings",
|
||||
"OpenAPIRerankEmbeddings",
|
||||
"SiliconFlowRerankEmbeddings",
|
||||
"QianFanEmbeddings",
|
||||
"TongYiEmbeddings",
|
||||
"WrappedEmbeddingFactory",
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Re-rank embeddings."""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import aiohttp
|
||||
@ -104,6 +105,24 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
|
||||
kwargs["session"] = session
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _parse_results(self, response: Dict[str, Any]) -> List[float]:
|
||||
"""Parse the response from the API.
|
||||
|
||||
Args:
|
||||
response: The response from the API.
|
||||
|
||||
Returns:
|
||||
List[float]: The rank scores of the candidates.
|
||||
"""
|
||||
data = response.get("data")
|
||||
if not data:
|
||||
if "detail" in response:
|
||||
raise RuntimeError(response["detail"])
|
||||
raise RuntimeError("Cannot find results in the response")
|
||||
if not isinstance(data, list):
|
||||
raise RuntimeError("Results should be a list")
|
||||
return data
|
||||
|
||||
def predict(self, query: str, candidates: List[str]) -> List[float]:
|
||||
"""Predict the rank scores of the candidates.
|
||||
|
||||
@ -126,7 +145,7 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
|
||||
self.api_url, json=data, timeout=self.timeout, headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()["data"]
|
||||
return self._parse_results(response.json())
|
||||
|
||||
async def apredict(self, query: str, candidates: List[str]) -> List[float]:
|
||||
"""Predict the rank scores of the candidates asynchronously."""
|
||||
@ -142,6 +161,50 @@ class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
|
||||
async with session.post(self.api_url, json=data) as resp:
|
||||
resp.raise_for_status()
|
||||
response_data = await resp.json()
|
||||
if "data" not in response_data:
|
||||
raise RuntimeError(response_data["detail"])
|
||||
return response_data["data"]
|
||||
return self._parse_results(response_data)
|
||||
|
||||
|
||||
class SiliconFlowRerankEmbeddings(OpenAPIRerankEmbeddings):
|
||||
"""SiliconFlow Rerank Model.
|
||||
|
||||
See `SiliconFlow API
|
||||
<https://docs.siliconflow.cn/api-reference/rerank/create-rerank>`_ for more details.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize the SiliconFlowRerankEmbeddings."""
|
||||
# If the API key is not provided, try to get it from the environment
|
||||
if "api_key" not in kwargs:
|
||||
kwargs["api_key"] = os.getenv("SILICON_FLOW_API_KEY")
|
||||
|
||||
if "api_url" not in kwargs:
|
||||
env_api_url = os.getenv("SILICON_FLOW_API_BASE")
|
||||
if env_api_url:
|
||||
env_api_url = env_api_url.rstrip("/")
|
||||
kwargs["api_url"] = env_api_url + "/rerank"
|
||||
else:
|
||||
kwargs["api_url"] = "https://api.siliconflow.cn/v1/rerank"
|
||||
|
||||
if "model_name" not in kwargs:
|
||||
kwargs["model_name"] = "BAAI/bge-reranker-v2-m3"
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _parse_results(self, response: Dict[str, Any]) -> List[float]:
|
||||
"""Parse the response from the API.
|
||||
|
||||
Args:
|
||||
response: The response from the API.
|
||||
|
||||
Returns:
|
||||
List[float]: The rank scores of the candidates.
|
||||
"""
|
||||
results = response.get("results")
|
||||
if not results:
|
||||
raise RuntimeError("Cannot find results in the response")
|
||||
if not isinstance(results, list):
|
||||
raise RuntimeError("Results should be a list")
|
||||
# Sort by index, 0 in the first element
|
||||
results = sorted(results, key=lambda x: x.get("index", 0))
|
||||
scores = [float(result.get("relevance_score")) for result in results]
|
||||
return scores
|
||||
|
Loading…
Reference in New Issue
Block a user